Tavernari's picture
Upload folder using huggingface_hub
148b631 verified
"""
metrics.py - Evaluation metrics for code completion validation.
Implement functions to calculate bracket accuracy, indentation,
and other code-specific metrics.
"""
import re
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
@dataclass
class TestResult:
    """Outcome of a single code-completion test case.

    Records the prompt, the model's generated completion, and how the
    completion fared against the expected and forbidden regex patterns.
    """
    test_name: str                 # unique identifier of the test case
    category: str                  # grouping key, e.g. 'brackets', 'indentation'
    passed: bool                   # overall pass/fail verdict for this case
    prompt: str                    # code prefix that was fed to the model
    generated: str                 # completion text produced by the model
    expected_patterns: List[str]   # regexes that should match `generated`
    matched_patterns: List[str]    # expected patterns that did match
    failed_patterns: List[str]     # expected patterns that did not match
    forbidden_matches: List[str]   # forbidden patterns found in `generated`
    score: float  # 0.0 to 1.0
@dataclass
class CategoryResult:
    """Aggregated result for a category.

    Summarizes all TestResult entries that share one category label.
    """
    category: str                   # the category label being summarized
    total_tests: int                # number of tests in this category
    passed_tests: int               # how many of them passed
    accuracy: float                 # passed_tests / total_tests (0 if empty)
    test_results: List[TestResult]  # the underlying per-test results
@dataclass
class ValidationReport:
    """Complete validation report for one model run.

    Combines overall pass counts with per-category breakdowns and the
    three specialized accuracy metrics.
    """
    model_name: str                              # model under evaluation
    total_tests: int                             # total number of test cases
    total_passed: int                            # number of passing test cases
    overall_accuracy: float                      # total_passed / total_tests
    category_results: Dict[str, CategoryResult]  # per-category aggregates
    bracket_accuracy: float                      # accuracy over 'brackets' tests
    indentation_accuracy: float                  # accuracy over 'indentation' tests
    structure_accuracy: float                    # accuracy over 'structure' tests
def check_brackets_balanced(text: str) -> Tuple[bool, str]:
    """
    Checks whether (), [] and {} in `text` are properly balanced and nested.

    Returns:
        (is_balanced, error_message) — message is "OK" when balanced,
        otherwise it describes the first problem found.
    """
    open_to_close = {'(': ')', '[': ']', '{': '}'}
    closers = set(open_to_close.values())
    pending = []  # stack of (opening char, its index in text)

    for idx, ch in enumerate(text):
        if ch in open_to_close:
            pending.append((ch, idx))
        elif ch in closers:
            if not pending:
                return False, f"Extra bracket '{ch}' at position {idx}"
            open_ch, open_pos = pending.pop()
            if open_to_close[open_ch] != ch:
                return False, f"Mismatch: '{open_ch}' at position {open_pos} closed with '{ch}' at position {idx}"

    if pending:
        # Anything still on the stack was never closed.
        unclosed = list(pending)
        return False, f"Unclosed brackets: {unclosed}"
    return True, "OK"
def count_bracket_errors(prompt: str, generated: str) -> Dict[str, int]:
    """
    Counts bracket errors in the concatenation of prompt and generated code.

    Args:
        prompt: Code prefix that preceded the generation.
        generated: Model-generated completion appended to the prompt.

    Returns:
        Dictionary with error counts:
        - 'unclosed_parens' / 'unclosed_brackets' / 'unclosed_braces':
          surplus of unmatched opening characters, per bracket type.
        - 'extra_closing': total surplus of closing characters pooled
          across all three bracket types.

    Note:
        Counting is purely lexical: brackets inside strings/comments are
        counted too, and interleaving mismatches like '([)]' are not
        detected — only net open/close imbalances are.
    """
    full_code = prompt + generated
    errors = {
        'unclosed_parens': 0,
        'unclosed_brackets': 0,
        'unclosed_braces': 0,
        'extra_closing': 0,
    }
    # One data-driven pass per bracket type instead of three copy-pasted
    # blocks: positive diff = unclosed openers, negative = extra closers.
    bracket_kinds = (
        ('(', ')', 'unclosed_parens'),
        ('[', ']', 'unclosed_brackets'),
        ('{', '}', 'unclosed_braces'),
    )
    for open_ch, close_ch, key in bracket_kinds:
        diff = full_code.count(open_ch) - full_code.count(close_ch)
        if diff > 0:
            errors[key] = diff
        elif diff < 0:
            errors['extra_closing'] += -diff
    return errors
def check_indentation(text: str) -> Dict[str, Any]:
    """
    Analyzes indentation quality in code.

    Args:
        text: Source code to analyze, possibly spanning multiple lines.

    Returns:
        Dictionary with indentation metrics:
        - 'total_lines': number of lines (including blank ones).
        - 'indented_lines': non-blank lines with leading whitespace.
        - 'consistent_indent': False if tabs and spaces are mixed.
        - 'indent_style': 'tabs', 'spaces', or None if nothing is indented.
          The style is fixed by the first indented line seen.
        - 'indent_size': smallest space-indent observed (None for tabs or
          when nothing is indented).
        - 'indent_errors': reserved; not populated by this implementation.
    """
    lines = text.split('\n')
    stats: Dict[str, Any] = {
        'total_lines': len(lines),
        'indented_lines': 0,
        'consistent_indent': True,
        'indent_style': None,  # 'spaces' or 'tabs'
        'indent_size': None,
        'indent_errors': []
    }
    indent_sizes = []
    for line in lines:
        if not line.strip():  # Skip blank/whitespace-only lines
            continue
        # Leading-whitespace width (a tab counts as one character here).
        stripped = line.lstrip()
        indent = len(line) - len(stripped)
        if indent > 0:
            stats['indented_lines'] += 1
            # Lock in the style from the first indented line; any later
            # line of the other style marks the file as inconsistent.
            if line.startswith('\t'):
                if stats['indent_style'] is None:
                    stats['indent_style'] = 'tabs'
                elif stats['indent_style'] == 'spaces':
                    stats['consistent_indent'] = False
            else:
                if stats['indent_style'] is None:
                    stats['indent_style'] = 'spaces'
                elif stats['indent_style'] == 'tabs':
                    stats['consistent_indent'] = False
            if stats['indent_style'] == 'spaces':
                indent_sizes.append(indent)
    if indent_sizes:
        # Smallest indent seen is taken as the base unit. This is a
        # heuristic, not a true GCD: indents of 4 and 6 yield 4.
        stats['indent_size'] = min(indent_sizes)
    return stats
def _pattern_matches(pattern: str, text: str) -> bool:
    """Return True if `pattern` matches `text` as a multiline regex.

    Falls back to a literal substring test when the pattern is not a
    valid regular expression.
    """
    try:
        return re.search(pattern, text, re.MULTILINE) is not None
    except re.error:
        return pattern in text


def evaluate_test_case(
    prompt: str,
    generated: str,
    expected_patterns: List[str],
    forbidden_patterns: List[str] = None
) -> Tuple[bool, float, List[str], List[str], List[str]]:
    """
    Evaluates a single test case against expected and forbidden patterns.

    Args:
        prompt: Code prefix fed to the model (kept for API symmetry;
            only `generated` is inspected).
        generated: Model-generated completion to evaluate.
        expected_patterns: Regexes (or literals, if invalid regexes) that
            should appear in `generated`.
        forbidden_patterns: Patterns that must NOT appear; defaults to none.

    Returns:
        (passed, score, matched_patterns, failed_patterns, forbidden_matches)
        - score is the fraction of expected patterns matched (1.0 when no
          patterns are expected), halved if any forbidden pattern appears.
        - passed requires at least one expected match (vacuously true when
          no patterns are expected) and no forbidden matches.
    """
    if forbidden_patterns is None:
        forbidden_patterns = []

    matched = [p for p in expected_patterns if _pattern_matches(p, generated)]
    failed = [p for p in expected_patterns if p not in matched]
    forbidden_found = [p for p in forbidden_patterns if _pattern_matches(p, generated)]

    score = len(matched) / len(expected_patterns) if expected_patterns else 1.0
    if forbidden_found:
        # Forbidden output halves the score rather than zeroing it.
        score *= 0.5

    # Fix: with no expected patterns the case passes vacuously (the
    # original returned passed=False despite a perfect 1.0 score).
    passed = (bool(matched) or not expected_patterns) and not forbidden_found
    return passed, score, matched, failed, forbidden_found
def calculate_bracket_accuracy(results: List[TestResult]) -> float:
    """Calculates accuracy specific to brackets.

    Returns the pass rate over results in the 'brackets' category,
    or 0.0 when there are none.
    """
    relevant = [r for r in results if r.category == 'brackets']
    if not relevant:
        return 0.0
    return sum(r.passed for r in relevant) / len(relevant)
def calculate_indentation_accuracy(results: List[TestResult]) -> float:
    """Calculates accuracy specific to indentation.

    Returns the pass rate over results in the 'indentation' category,
    or 0.0 when there are none.
    """
    total = 0
    hits = 0
    for result in results:
        if result.category == 'indentation':
            total += 1
            if result.passed:
                hits += 1
    return hits / total if total else 0.0
def generate_report(
    model_name: str,
    results: List[TestResult]
) -> ValidationReport:
    """
    Builds the complete validation report for one model run.

    Groups results by category, computes per-category accuracies, then
    assembles the overall and specialized metrics into a ValidationReport.
    """
    # Bucket the individual results by their category label.
    by_category = {}
    for res in results:
        by_category.setdefault(res.category, []).append(res)

    # One CategoryResult per bucket.
    category_results = {}
    for name, bucket in by_category.items():
        n_passed = sum(r.passed for r in bucket)
        category_results[name] = CategoryResult(
            category=name,
            total_tests=len(bucket),
            passed_tests=n_passed,
            accuracy=n_passed / len(bucket) if bucket else 0,
            test_results=bucket,
        )

    total = len(results)
    total_passed = sum(r.passed for r in results)
    structure_tests = [r for r in results if r.category == 'structure']
    structure_passed = sum(r.passed for r in structure_tests)

    return ValidationReport(
        model_name=model_name,
        total_tests=total,
        total_passed=total_passed,
        overall_accuracy=total_passed / total if total > 0 else 0,
        category_results=category_results,
        bracket_accuracy=calculate_bracket_accuracy(results),
        indentation_accuracy=calculate_indentation_accuracy(results),
        # max(1, ...) guards against division by zero when there are no
        # structure tests (yielding 0.0).
        structure_accuracy=structure_passed / max(1, len(structure_tests)),
    )
def format_report(report: ValidationReport) -> str:
    """Renders a ValidationReport as a human-readable multi-line string."""
    divider = "=" * 60
    out = [
        divider,
        f"📊 VALIDATION REPORT: {report.model_name}",
        divider,
        "",
        "📈 OVERALL RESULTS",
        f" Total tests: {report.total_tests}",
        f" Passed tests: {report.total_passed}",
        f" Overall Accuracy: {report.overall_accuracy:.1%}",
        "",
        "📋 SPECIFIC METRICS",
        f" Bracket Accuracy: {report.bracket_accuracy:.1%}",
        f" Indentation Accuracy: {report.indentation_accuracy:.1%}",
        f" Structure Accuracy: {report.structure_accuracy:.1%}",
        "",
        "📁 RESULTS BY CATEGORY",
    ]
    for name, cat in report.category_results.items():
        # Icon thresholds: >=70% good, >=50% warning, otherwise failing.
        if cat.accuracy >= 0.7:
            icon = "✅"
        elif cat.accuracy >= 0.5:
            icon = "⚠️"
        else:
            icon = "❌"
        out.append(f" {icon} {name}: {cat.passed_tests}/{cat.total_tests} ({cat.accuracy:.1%})")
    out.append("")
    out.append(divider)
    return "\n".join(out)
if __name__ == '__main__':
    # Quick smoke test of the metric helpers when run as a script.
    print("🧪 Testing metrics...")

    # Bracket balance: one balanced signature, one with an unclosed paren.
    balanced, detail = check_brackets_balanced("def foo(a, b):")
    print(f"Balanced '(a, b)': {balanced} - {detail}")

    balanced, detail = check_brackets_balanced("def foo(a, b:")
    print(f"Balanced '(a, b:': {balanced} - {detail}")

    # Pattern-based evaluation of a sample generated completion.
    ok, pts, hits, misses, banned = evaluate_test_case(
        prompt="def hello(",
        generated="name):\n print(name)",
        expected_patterns=[r"\)", r":"]
    )
    print(f"Test result: passed={ok}, score={pts}, matched={hits}")