|
|
""" |
|
|
Data processing utilities for the Coding Expert model |
|
|
""" |
|
|
import ast
import datetime
import hashlib
import json
import logging
import os
import re
import textwrap
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import jsonlines
import numpy as np
import pandas as pd
from datasets import Dataset
from tqdm import tqdm
|
|
|
|
|
class CodeDataProcessor:
    """Clean, analyze, and serialize code samples from benchmark datasets.

    Each supported dataset has a ``_process_<name>`` method that maps one raw
    example to a flat dict and attaches a ``code_analysis`` entry produced by
    :meth:`process_code`.
    """

    # AST node types that each add exactly one decision point to the
    # cyclomatic complexity (an ``elif`` is a nested ``ast.If``).
    _BRANCH_NODES = (
        ast.If,
        ast.IfExp,
        ast.For,
        ast.AsyncFor,
        ast.While,
        ast.ExceptHandler,
    )

    # Word-boundary patterns: plain substring tests would treat "default" as
    # "def", "global_count" as "global", and ``ast.literal_eval(`` or
    # ``model.eval(`` as a call to the builtin ``eval``.
    _CLASS_RE = re.compile(r"\bclass\b")
    _DEF_RE = re.compile(r"\bdef\b")
    _GLOBAL_RE = re.compile(r"\bglobal\b")
    _EVAL_CALL_RE = re.compile(r"(?<![\w.])eval\s*\(")

    def __init__(self, output_dir: str = "processed_data"):
        """Create the processor and make sure *output_dir* exists.

        Args:
            output_dir: Directory that :meth:`save_to_jsonl` writes into.
                Missing parent directories are created too.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.logger = self._setup_logger()

    def _setup_logger(self) -> logging.Logger:
        """Return the module logger at INFO level with one stream handler."""
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)
        # getLogger() returns the same object for every instance. Attach the
        # handler only once; otherwise each new processor duplicates every
        # log line.
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            )
            logger.addHandler(handler)
        return logger

    def process_code(self, code: str, language: str = "python") -> Dict[str, Any]:
        """Clean and analyze a single code snippet.

        Args:
            code: Source text of the snippet.
            language: Language tag. AST analysis is only performed for
                ``"python"``; other languages get an empty ``ast_info``.

        Returns:
            A dict with keys ``code`` (the cleaned snippet), ``language``,
            ``ast_info``, ``metrics`` and ``patterns``. If an unexpected error
            occurs, returns ``{"error": message}`` instead and logs a warning.
        """
        try:
            code = self._clean_code(code)
            ast_info = self._parse_ast(code, language)
            metrics = self._extract_code_metrics(code, ast_info)
            patterns = self._identify_patterns(code)
            return {
                "code": code,
                "language": language,
                "ast_info": ast_info,
                "metrics": metrics,
                "patterns": patterns,
            }
        except Exception as e:
            # Best-effort analysis: a bad snippet must not abort a dataset run.
            self.logger.warning(f"Error processing code: {str(e)}")
            return {"error": str(e)}

    def _clean_code(self, code: str) -> str:
        """Normalize whitespace in *code* without changing its structure.

        Strips trailing whitespace from every line, drops blank lines, and
        removes the indentation common to all lines. Relative indentation is
        kept because it is syntactically significant in Python. This matters
        for indented fragments such as method bodies. Comments are NOT
        removed.
        """
        stripped = (line.rstrip() for line in code.split("\n"))
        return textwrap.dedent("\n".join(line for line in stripped if line))

    def _parse_ast(self, code: str, language: str) -> Dict[str, Any]:
        """Summarize the structure of Python *code* via its AST.

        Returns:
            ``{}`` for non-Python languages.
            ``{"error": message}`` if the code cannot be parsed.
            Otherwise ``num_functions`` (sync and async, including nested
            ones), ``num_classes`` and ``complexity``.
        """
        if language != "python":
            return {}
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError, RecursionError) as e:
            # SyntaxError: invalid code; ValueError: NUL bytes on older
            # Pythons; RecursionError: pathologically deep nesting.
            return {"error": str(e)}

        num_functions = 0
        num_classes = 0
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                num_functions += 1
            elif isinstance(node, ast.ClassDef):
                num_classes += 1
        return {
            "num_functions": num_functions,
            "num_classes": num_classes,
            "complexity": self._calculate_complexity(tree),
        }

    def _calculate_complexity(self, tree: ast.AST) -> int:
        """Approximate the McCabe cyclomatic complexity of *tree*.

        Starts at 1 and adds one per decision point:
        - ``if``/``elif`` and conditional expressions;
        - ``for``/``async for``/``while`` loops;
        - ``except`` handlers;
        - each comprehension loop and each of its ``if`` filters;
        - each extra operand of an ``and``/``or`` chain.

        ``try`` itself is not counted: the extra paths it creates are its
        ``except`` handlers, which are counted individually.
        """
        complexity = 1
        for node in ast.walk(tree):
            if isinstance(node, self._BRANCH_NODES):
                complexity += 1
            elif isinstance(node, ast.BoolOp):
                # `a and b and c` has two short-circuit decisions.
                complexity += len(node.values) - 1
            elif isinstance(node, ast.comprehension):
                complexity += 1 + len(node.ifs)
        return complexity

    def _extract_code_metrics(self, code: str, ast_info: Dict[str, Any]) -> Dict[str, Any]:
        """Compute size, token and AST-derived metrics for *code*.

        Tokens are whitespace-separated chunks, not language tokens. AST
        counts default to 0 when *ast_info* lacks them, e.g. for non-Python
        code or parse errors.
        """
        tokens = code.split()
        token_counts = Counter(tokens)
        return {
            "length": len(code),
            # Empty code has zero lines, not one.
            "lines": code.count("\n") + 1 if code else 0,
            "tokens": len(tokens),
            "unique_tokens": len(token_counts),
            "ast_complexity": ast_info.get("complexity", 0),
            "function_count": ast_info.get("num_functions", 0),
            "class_count": ast_info.get("num_classes", 0),
            "token_distribution": token_counts.most_common(5),
        }

    def _identify_patterns(self, code: str) -> Dict[str, List[str]]:
        """Flag a few coarse design patterns, anti-patterns and security issues.

        This is a lexical heuristic. Keywords inside comments or string
        literals also match.
        """
        patterns: Dict[str, List[str]] = {
            "design_patterns": [],
            "anti_patterns": [],
            "security_issues": [],
        }
        if self._CLASS_RE.search(code) and self._DEF_RE.search(code):
            patterns["design_patterns"].append("Class-based design")
        if self._GLOBAL_RE.search(code):
            patterns["anti_patterns"].append("Global variables")
        if self._EVAL_CALL_RE.search(code):
            patterns["security_issues"].append("Eval usage")
        return patterns

    def process_dataset(self, dataset: "Dataset", dataset_name: str) -> List[Dict[str, Any]]:
        """Process every example of *dataset*, skipping the ones that fail.

        Args:
            dataset: Sized iterable of example dicts (annotated as a
                ``datasets.Dataset``).
            dataset_name: One of the names handled by :meth:`_process_example`.

        Returns:
            The successfully processed examples in input order. Failures are
            logged and counted but omitted.
        """
        processed = []
        error_count = 0
        self.logger.info(f"Processing {dataset_name} dataset with {len(dataset)} samples")

        for idx, example in enumerate(tqdm(dataset, desc=f"Processing {dataset_name}")):
            try:
                processed.append(self._process_example(example, dataset_name))
            except Exception as e:
                # One malformed example must not abort the whole run.
                error_count += 1
                self.logger.error(f"Error processing example {idx} in {dataset_name}: {str(e)}")

        self.logger.info(f"Processed {len(processed)} examples")
        self.logger.info(f"Encountered {error_count} errors")
        return processed

    def _process_example(self, example: Dict[str, Any], dataset_name: str) -> Dict[str, Any]:
        """Dispatch *example* to the processor for *dataset_name*.

        Raises:
            ValueError: If *dataset_name* is not a supported dataset.
        """
        handlers = {
            "CodeSearchNet": self._process_code_search_net,
            "HumanEval": self._process_human_eval,
            "MBPP": self._process_mbpp,
            "Spider": self._process_spider,
            "DeepFix": self._process_deep_fix,
            "CodeXGLUE": self._process_codexglue,
        }
        try:
            handler = handlers[dataset_name]
        except KeyError:
            raise ValueError(f"Unknown dataset: {dataset_name}") from None
        return handler(example)

    def _process_code_search_net(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Process a CodeSearchNet example.

        The example's own ``language`` is used for analysis, so non-Python
        code is not run through the Python parser.
        """
        return {
            "code": example["code"].strip(),
            "docstring": example["docstring"].strip(),
            "language": example["language"],
            "function_name": example["function_name"],
            "code_analysis": self.process_code(example["code"], example["language"]),
        }

    def _process_human_eval(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Process a HumanEval example; its canonical solution is analyzed."""
        return {
            "task_id": example["task_id"],
            "prompt": example["prompt"].strip(),
            "solution": example["canonical_solution"].strip(),
            "test": example["test"].strip(),
            "entry_point": example["entry_point"],
            "code_analysis": self.process_code(example["canonical_solution"]),
        }

    def _process_mbpp(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Process an MBPP example; its reference ``code`` is analyzed."""
        return {
            "task_id": example["task_id"],
            "problem": example["text"].strip(),
            "solution": example["code"].strip(),
            "test_list": example["test_list"],
            "challenge_test_list": example["challenge_test_list"],
            "code_analysis": self.process_code(example["code"]),
        }

    def _process_spider(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Process a Spider (text-to-SQL) example; its ``sql`` is analyzed as SQL.

        NOTE(review): this assumes ``example["sql"]`` is a string. Confirm
        against the loader in use; some Spider releases store a parsed
        structure under ``sql`` and the raw query under ``query``.
        """
        return {
            "query": example["query"].strip(),
            "question": example["question"].strip(),
            "db_id": example["db_id"],
            "sql": example["sql"].strip(),
            "code_analysis": self.process_code(example["sql"], "sql"),
        }

    def _process_deep_fix(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Process a DeepFix example; the fixed program is analyzed as C.

        DeepFix consists of C programs, so they are tagged ``"c"`` rather
        than being fed to the Python parser.
        """
        return {
            "original_code": example["code"].strip(),
            "fixed_code": example["fixed_code"].strip(),
            "error_type": example["error_type"],
            "code_analysis": self.process_code(example["fixed_code"], "c"),
        }

    def _process_codexglue(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Process a CodeXGLUE example using its own ``language`` tag."""
        return {
            "code": example["code"].strip(),
            "docstring": example["docstring"].strip(),
            "task": example["task"],
            "language": example["language"],
            "code_analysis": self.process_code(example["code"], example["language"]),
        }

    def save_to_jsonl(self, data: List[Dict[str, Any]], filename: str) -> Path:
        """Write *data* as JSON Lines to ``output_dir / filename``; return that path."""
        filepath = self.output_dir / filename
        with jsonlines.open(filepath, mode='w') as writer:
            writer.write_all(data)
        self.logger.info(f"Saved data to {filepath}")
        return filepath

    def print_sample(self, data: List[Dict[str, Any]], count: int = 3):
        """Log the first *count* examples of *data* as indented JSON."""
        self.logger.info("\nSample data:")
        for i, example in enumerate(data[:count]):
            self.logger.info(f"\nSample {i+1}:")
            self.logger.info(json.dumps(example, indent=2))

    def print_memory_usage(self):
        """Log this process's resident set size (RSS) in MB.

        Uses the optional third-party ``psutil`` package. It is imported
        lazily so the rest of the module works without it. When it is
        missing, a warning is logged instead.
        """
        try:
            import psutil  # optional dependency
        except ImportError:
            self.logger.warning("psutil is not installed; cannot report memory usage")
            return
        memory_info = psutil.Process().memory_info()
        self.logger.info(f"Current memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
|
|