""" metrics.py - Evaluation metrics for code completion validation. Implement functions to calculate bracket accuracy, indentation, and other code-specific metrics. """ import re from typing import List, Tuple, Dict from dataclasses import dataclass from collections import Counter @dataclass class TestResult: """Individual test result.""" test_name: str category: str passed: bool prompt: str generated: str expected_patterns: List[str] matched_patterns: List[str] failed_patterns: List[str] forbidden_matches: List[str] score: float # 0.0 to 1.0 @dataclass class CategoryResult: """Aggregated result for a category.""" category: str total_tests: int passed_tests: int accuracy: float test_results: List[TestResult] @dataclass class ValidationReport: """Complete validation report.""" model_name: str total_tests: int total_passed: int overall_accuracy: float category_results: Dict[str, CategoryResult] bracket_accuracy: float indentation_accuracy: float structure_accuracy: float def check_brackets_balanced(text: str) -> Tuple[bool, str]: """ Checks if brackets are balanced. Returns: (is_balanced, error_message) """ stack = [] pairs = {'(': ')', '[': ']', '{': '}'} for i, char in enumerate(text): if char in pairs: stack.append((char, i)) elif char in pairs.values(): if not stack: return False, f"Extra bracket '{char}' at position {i}" opening, pos = stack.pop() if pairs[opening] != char: return False, f"Mismatch: '{opening}' at position {pos} closed with '{char}' at position {i}" if stack: unclosed = [(char, pos) for char, pos in stack] return False, f"Unclosed brackets: {unclosed}" return True, "OK" def count_bracket_errors(prompt: str, generated: str) -> Dict[str, int]: """ Counts bracket errors in generated code. Returns: Dictionary with error counts by type """ full_code = prompt + generated errors = { 'unclosed_parens': 0, 'unclosed_brackets': 0, 'unclosed_braces': 0, 'extra_closing': 0 } # Count open and close parens = full_code.count('(') - full_code.count(')') brackets = full_code.count('[') - full_code.count(']') braces = full_code.count('{') - full_code.count('}') if parens > 0: errors['unclosed_parens'] = parens elif parens < 0: errors['extra_closing'] += abs(parens) if brackets > 0: errors['unclosed_brackets'] = brackets elif brackets < 0: errors['extra_closing'] += abs(brackets) if braces > 0: errors['unclosed_braces'] = braces elif braces < 0: errors['extra_closing'] += abs(braces) return errors def check_indentation(text: str) -> Dict[str, any]: """ Analyzes indentation quality in code. Returns: Dictionary with indentation metrics """ lines = text.split('\n') stats = { 'total_lines': len(lines), 'indented_lines': 0, 'consistent_indent': True, 'indent_style': None, # 'spaces' or 'tabs' 'indent_size': None, 'indent_errors': [] } indent_sizes = [] for i, line in enumerate(lines): if not line.strip(): # Empty line continue # Count leading whitespace stripped = line.lstrip() indent = len(line) - len(stripped) if indent > 0: stats['indented_lines'] += 1 # Detect style if line.startswith('\t'): if stats['indent_style'] is None: stats['indent_style'] = 'tabs' elif stats['indent_style'] == 'spaces': stats['consistent_indent'] = False else: if stats['indent_style'] is None: stats['indent_style'] = 'spaces' elif stats['indent_style'] == 'tabs': stats['consistent_indent'] = False if stats['indent_style'] == 'spaces': indent_sizes.append(indent) # Determine most common indent size if indent_sizes: # Find GCD of indent sizes common_indents = Counter(indent_sizes) stats['indent_size'] = min(common_indents.keys()) if common_indents else 4 return stats def evaluate_test_case( prompt: str, generated: str, expected_patterns: List[str], forbidden_patterns: List[str] = None ) -> Tuple[bool, float, List[str], List[str], List[str]]: """ Evaluates a test case. Returns: (passed, score, matched_patterns, failed_patterns, forbidden_matches) """ if forbidden_patterns is None: forbidden_patterns = [] matched = [] failed = [] forbidden_found = [] # Check expected patterns for pattern in expected_patterns: try: if re.search(pattern, generated, re.MULTILINE): matched.append(pattern) else: failed.append(pattern) except re.error: # Invalid pattern, treat as literal if pattern in generated: matched.append(pattern) else: failed.append(pattern) # Check forbidden patterns for pattern in forbidden_patterns: try: if re.search(pattern, generated, re.MULTILINE): forbidden_found.append(pattern) except re.error: if pattern in generated: forbidden_found.append(pattern) # Calculate score if expected_patterns: score = len(matched) / len(expected_patterns) else: score = 1.0 # Penalize forbidden patterns if forbidden_found: score *= 0.5 passed = len(matched) > 0 and len(forbidden_found) == 0 return passed, score, matched, failed, forbidden_found def calculate_bracket_accuracy(results: List[TestResult]) -> float: """Calculates accuracy specific to brackets.""" bracket_tests = [r for r in results if r.category == 'brackets'] if not bracket_tests: return 0.0 return sum(1 for t in bracket_tests if t.passed) / len(bracket_tests) def calculate_indentation_accuracy(results: List[TestResult]) -> float: """Calculates accuracy specific to indentation.""" indent_tests = [r for r in results if r.category == 'indentation'] if not indent_tests: return 0.0 return sum(1 for t in indent_tests if t.passed) / len(indent_tests) def generate_report( model_name: str, results: List[TestResult] ) -> ValidationReport: """ Generates complete validation report. """ # Group by category categories = {} for result in results: if result.category not in categories: categories[result.category] = [] categories[result.category].append(result) # Calculate results per category category_results = {} for cat, cat_results in categories.items(): passed = sum(1 for r in cat_results if r.passed) category_results[cat] = CategoryResult( category=cat, total_tests=len(cat_results), passed_tests=passed, accuracy=passed / len(cat_results) if cat_results else 0, test_results=cat_results ) # Calculate general metrics total = len(results) passed = sum(1 for r in results if r.passed) return ValidationReport( model_name=model_name, total_tests=total, total_passed=passed, overall_accuracy=passed / total if total > 0 else 0, category_results=category_results, bracket_accuracy=calculate_bracket_accuracy(results), indentation_accuracy=calculate_indentation_accuracy(results), structure_accuracy=sum(1 for r in results if r.category == 'structure' and r.passed) / max(1, len([r for r in results if r.category == 'structure'])) ) def format_report(report: ValidationReport) -> str: """Formats report for printing.""" lines = [ "=" * 60, f"๐Ÿ“Š VALIDATION REPORT: {report.model_name}", "=" * 60, "", f"๐Ÿ“ˆ OVERALL RESULTS", f" Total tests: {report.total_tests}", f" Passed tests: {report.total_passed}", f" Overall Accuracy: {report.overall_accuracy:.1%}", "", "๐Ÿ“‹ SPECIFIC METRICS", f" Bracket Accuracy: {report.bracket_accuracy:.1%}", f" Indentation Accuracy: {report.indentation_accuracy:.1%}", f" Structure Accuracy: {report.structure_accuracy:.1%}", "", "๐Ÿ“ RESULTS BY CATEGORY", ] for cat_name, cat_result in report.category_results.items(): status = "โœ…" if cat_result.accuracy >= 0.7 else "โš ๏ธ" if cat_result.accuracy >= 0.5 else "โŒ" lines.append(f" {status} {cat_name}: {cat_result.passed_tests}/{cat_result.total_tests} ({cat_result.accuracy:.1%})") lines.extend([ "", "=" * 60 ]) return "\n".join(lines) if __name__ == '__main__': # Function tests print("๐Ÿงช Testing metrics...") # Bracket test is_bal, msg = check_brackets_balanced("def foo(a, b):") print(f"Balanced '(a, b)': {is_bal} - {msg}") is_bal, msg = check_brackets_balanced("def foo(a, b:") print(f"Balanced '(a, b:': {is_bal} - {msg}") # Evaluation test passed, score, matched, failed, forbidden = evaluate_test_case( prompt="def hello(", generated="name):\n print(name)", expected_patterns=[r"\)", r":"] ) print(f"Test result: passed={passed}, score={score}, matched={matched}")