|
|
""" |
|
|
Data preparation script for the Coding Expert model |
|
|
""" |
|
|
import os |
|
|
import json |
|
|
from pathlib import Path |
|
|
import jsonlines |
|
|
from typing import Dict, List, Any |
|
|
import sys |
|
|
import psutil |
|
|
from datasets import load_dataset |
|
|
import ast |
|
|
import numpy as np |
|
|
|
|
|
from data_processor import CodeDataProcessor |
|
|
|
|
|
class CodeDataPreparer:
    """Normalize examples from several code datasets and write them as JSONL.

    Attributes:
        output_dir: ``Path`` of the directory that ``save_to_jsonl`` writes
            into. It is created when the instance is constructed.
        datasets: Mapping of dataset name to a config dict with the keys
            ``"source"``, ``"split"`` and ``"fields"``. Nothing in this class
            reads ``"fields"``; it only records which keys the matching
            ``_process_*`` method expects.
    """

    def __init__(self, output_dir: str = "processed_data"):
        """Create the output directory and register the dataset configs.

        Args:
            output_dir: Directory for the generated JSONL files.
        """
        self.output_dir = Path(output_dir)
        # parents=True lets a nested path such as "out/processed" be created
        # in one go instead of raising FileNotFoundError.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): the "source" identifiers and field lists below have not
        # been verified against the dataset hub -- confirm them before running.
        self.datasets = {
            "CodeSearchNet": {
                "source": "codeium/codeium",
                "split": "train",
                "fields": ["code", "docstring", "language", "function_name"]
            },
            "HumanEval": {
                "source": "openai/human_eval",
                "split": "test",
                "fields": ["task_id", "prompt", "canonical_solution", "test", "entry_point"]
            },
            "MBPP": {
                "source": "mbpp/mbpp",
                "split": "train",
                "fields": ["task_id", "text", "code", "test_list", "challenge_test_list"]
            },
            "Spider": {
                "source": "yale-lily/spider",
                "split": "train",
                "fields": ["query", "question", "db_id", "sql"]
            },
            "DeepFix": {
                "source": "deepfix/deepfix",
                "split": "train",
                "fields": ["code", "fixed_code", "error_type"]
            },
            "CodeXGLUE": {
                "source": "microsoft/CodeXGLUE",
                "split": "train",
                "fields": ["code", "docstring", "task", "language"]
            }
        }

    def process_dataset(self, dataset: List[Dict[str, Any]], dataset_name: str) -> List[Dict[str, Any]]:
        """Process every example of one dataset, skipping examples that fail.

        Args:
            dataset: Iterable of raw example dicts.
            dataset_name: Name selecting the per-dataset processor.

        Returns:
            The processed examples, in input order. An example whose
            processing raises is reported on stdout and left out.
        """
        processed = []
        error_count = 0

        print(f"\nProcessing {dataset_name} dataset...")

        for idx, example in enumerate(dataset):
            try:
                processed.append(self._process_example(dataset_name, example))
            except Exception as e:
                # Best effort: one malformed record must not abort the run.
                print(f"Error processing example {idx} in {dataset_name}: {str(e)}")
                error_count += 1

        print(f"Processed {len(processed)} examples from {dataset_name}")
        print(f"Encountered {error_count} errors during processing")
        return processed

    def _process_example(self, dataset_name: str, example: Dict[str, Any]) -> Dict[str, Any]:
        """Route a single example to the processor for its dataset.

        Raises:
            ValueError: If ``dataset_name`` has no registered processor.
        """
        handlers = {
            "CodeSearchNet": self._process_code_search_net,
            "HumanEval": self._process_human_eval,
            "MBPP": self._process_mbpp,
            "Spider": self._process_spider,
            "DeepFix": self._process_deep_fix,
            "CodeXGLUE": self._process_codexglue,
        }
        handler = handlers.get(dataset_name)
        if handler is None:
            raise ValueError(f"Unknown dataset: {dataset_name}")
        return handler(example)

    def _process_code_search_net(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a CodeSearchNet example and attach a code analysis."""
        return {
            "code": example["code"].strip(),
            "docstring": example["docstring"].strip(),
            "language": example["language"],
            "function_name": example["function_name"],
            # The analysis runs on the raw (unstripped) code.
            "code_analysis": self._analyze_code(example["code"])
        }

    def _process_human_eval(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a HumanEval example; the analysis covers the canonical solution."""
        return {
            "task_id": example["task_id"],
            "prompt": example["prompt"].strip(),
            "solution": example["canonical_solution"].strip(),
            "test": example["test"].strip(),
            "entry_point": example["entry_point"],
            "code_analysis": self._analyze_code(example["canonical_solution"])
        }

    def _process_mbpp(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize an MBPP example (``text`` -> ``problem``, ``code`` -> ``solution``)."""
        return {
            "task_id": example["task_id"],
            "problem": example["text"].strip(),
            "solution": example["code"].strip(),
            "test_list": example["test_list"],
            "challenge_test_list": example["challenge_test_list"],
            "code_analysis": self._analyze_code(example["code"])
        }

    def _process_spider(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a Spider example.

        ``_analyze_code`` uses Python's parser, so ``code_analysis`` is an
        ``{"error": ...}`` dict whenever the ``sql`` value is not valid Python.
        """
        return {
            "query": example["query"].strip(),
            "question": example["question"].strip(),
            "db_id": example["db_id"],
            "sql": example["sql"].strip(),
            "code_analysis": self._analyze_code(example["sql"])
        }

    def _process_deep_fix(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a DeepFix example; the analysis covers the fixed code."""
        return {
            "original_code": example["code"].strip(),
            "fixed_code": example["fixed_code"].strip(),
            "error_type": example["error_type"],
            "code_analysis": self._analyze_code(example["fixed_code"])
        }

    def _process_codexglue(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a CodeXGLUE example and attach a code analysis."""
        return {
            "code": example["code"].strip(),
            "docstring": example["docstring"].strip(),
            "task": example["task"],
            "language": example["language"],
            "code_analysis": self._analyze_code(example["code"])
        }

    def _analyze_code(self, code: str) -> Dict[str, Any]:
        """Parse ``code`` as Python and summarize its structure.

        Returns:
            ``{"num_functions", "num_classes", "complexity"}`` on success, or
            ``{"error": <message>}`` if parsing or analysis raises.
        """
        try:
            tree = ast.parse(code)
            num_functions = 0
            num_classes = 0
            # One walk for both counts; ``async def`` is a function too.
            for node in ast.walk(tree):
                if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    num_functions += 1
                elif isinstance(node, ast.ClassDef):
                    num_classes += 1
            return {
                "num_functions": num_functions,
                "num_classes": num_classes,
                "complexity": self._calculate_complexity(tree)
            }
        except Exception as e:
            # Deliberately broad: the analysis is optional metadata and must
            # never prevent the example itself from being emitted.
            return {"error": str(e)}

    def _calculate_complexity(self, tree: ast.AST) -> int:
        """Return an approximate cyclomatic complexity for ``tree``.

        The value is 1 plus the number of ``if``, ``for``, ``async for``,
        ``while``, ``try`` and ``except`` nodes found anywhere in the tree.
        A ``try`` statement and each of its handlers are counted separately.
        """
        branch_nodes = (
            ast.If,
            ast.For,
            ast.AsyncFor,
            ast.While,
            ast.Try,
            ast.ExceptHandler,
        )
        return 1 + sum(isinstance(node, branch_nodes) for node in ast.walk(tree))

    def save_to_jsonl(self, data: List[Dict[str, Any]], filename: str):
        """Write ``data`` to ``output_dir / filename``, one JSON object per line.

        Returns:
            The ``Path`` of the written file. An existing file is overwritten.
        """
        filepath = self.output_dir / filename
        with jsonlines.open(filepath, mode='w') as writer:
            writer.write_all(data)
        return filepath

    def print_sample(self, data: List[Dict[str, Any]], count: int = 3):
        """Pretty-print the first ``count`` examples of ``data`` to stdout."""
        print("\nSample data:")
        for i, example in enumerate(data[:count]):
            print(f"\nSample {i+1}:")
            print(json.dumps(example, indent=2))

    def print_memory_usage(self):
        """Print the resident set size of the current process in MB."""
        process = psutil.Process()
        memory_info = process.memory_info()
        print(f"Current memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
|
|
|
|
|
def main():
    """Load, process, preview and save every dataset configured on the preparer.

    Each dataset is handled independently: any exception raised while
    loading, processing, previewing or saving it is reported on stdout and
    the loop moves on to the next dataset.
    """
    preparer = CodeDataPreparer()

    for name, cfg in preparer.datasets.items():
        try:
            print(f"\nLoading {name} dataset...")
            raw = load_dataset(cfg["source"], split=cfg["split"])
            print(f"Loaded {len(raw)} samples from {name}")

            records = preparer.process_dataset(raw, name)
            print(f"Processed {len(records)} samples")

            # Show a few records before anything is written to disk.
            preparer.print_sample(records)

            target_name = f"{name.lower()}_processed.jsonl"
            saved_to = preparer.save_to_jsonl(records, target_name)
            print(f"\nSaved {name} data to: {saved_to}")
        except Exception as err:
            print(f"Error processing {name} dataset: {str(err)}")
            print("Continuing with next dataset...")

        # NOTE(review): the original indentation of this call was lost; it is
        # assumed to run once per dataset -- confirm it was not meant to run
        # only once after the loop.
        preparer.print_memory_usage()
|
|
|
|
|
# Run the preparation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|