# Page header captured with this file: Spaces - Runtime error - File size: 3,574 Bytes.
import os
import json
import ast
import tempfile
from git import Repo, exc
# --- Configuration ---
# Repositories to mine: short name -> clone URL.
REPO_CONFIG = {
    "fastapi": "https://github.com/tiangolo/fastapi.git",
    "requests": "https://github.com/psf/requests.git",
    "scikit-learn": "https://github.com/scikit-learn/scikit-learn.git",
}

# JSON Lines file that receives the generated examples.
OUTPUT_FILE = "diff_dataset.jsonl"

# Upper bound on how many commits are walked in each repository.
MAX_COMMITS_PER_REPO = 5_000
class FuncParser(ast.NodeVisitor):
    """Collect the docstring of every function definition in an AST.

    After ``visit(tree)`` returns, ``functions`` maps each function's bare
    name to its docstring (``""`` when it has none). Names are not
    qualified, so when several definitions share a name (for example
    methods of different classes) the one visited last wins.
    """

    def __init__(self):
        super().__init__()
        # function name -> docstring text ("" if the function has none)
        self.functions = {}

    def visit_FunctionDef(self, node):
        """Record ``node``'s docstring, then descend into nested definitions."""
        self.functions[node.name] = ast.get_docstring(node) or ""
        self.generic_visit(node)

    # ``async def`` parses to a distinct node type; without this alias
    # coroutine functions (and anything nested in them) are never recorded.
    visit_AsyncFunctionDef = visit_FunctionDef
def get_functions_from_source(source_code):
    """Return a ``{function name: docstring}`` mapping for ``source_code``.

    Args:
        source_code: Text of a Python module.

    Returns:
        The ``functions`` dict built by ``FuncParser``, or an empty dict
        when the text cannot be parsed by the running interpreter.
    """
    try:
        tree = ast.parse(source_code)
    except (SyntaxError, ValueError):
        # SyntaxError: not valid Python for this interpreter.
        # ValueError: raised instead of SyntaxError for NUL bytes on
        # Python versions before 3.12.
        return {}
    parser = FuncParser()
    parser.visit(tree)
    return parser.functions
def format_for_model(diff_text, old_doc, new_doc):
    """Build one training example in the instruction-prompt layout.

    Args:
        diff_text: Patch text, placed inside a fenced ``diff`` block.
        old_doc: Docstring before the change; surrounding whitespace is stripped.
        new_doc: Docstring after the change; surrounding whitespace is stripped.

    Returns:
        A dict with the single key ``"text"`` holding the rendered prompt,
        terminated by a newline.
    """
    sections = [
        "### INSTRUCTION:",
        "A Python function's code was changed. Based on the `git diff` provided, update the function's documentation.",
        "### GIT DIFF:",
        "```diff",
        diff_text,
        # Close the fence so the documentation sections are not swallowed
        # by the diff block.
        "```",
        "### OLD DOCUMENTATION:",
        old_doc.strip(),
        "### UPDATED DOCUMENTATION:",
        new_doc.strip(),
    ]
    return {"text": "\n".join(sections) + "\n"}
def _mine_repo(repo):
    """Yield one formatted example per function whose docstring changed.

    Walks up to ``MAX_COMMITS_PER_REPO`` commits of ``repo`` and compares
    each commit with its first parent.
    """
    for commit in repo.iter_commits(max_count=MAX_COMMITS_PER_REPO):
        if not commit.parents:
            continue  # root commit: nothing to compare against
        parent = commit.parents[0]
        # Diff FROM the parent TO the commit so that the a-side is the old
        # version and the b-side is the new one. ``commit.diff(parent)``
        # yields the reverse, which would swap old/new docs and the patch.
        for diff in parent.diff(commit, create_patch=True, unified=0):
            if not (diff.a_path and diff.b_path
                    and diff.a_path.endswith('.py') and diff.b_path.endswith('.py')):
                continue
            if diff.a_blob is None or diff.b_blob is None:
                continue  # file added or deleted: no before/after pair
            try:
                old_source = diff.a_blob.data_stream.read().decode('utf-8')
                new_source = diff.b_blob.data_stream.read().decode('utf-8')
            except UnicodeDecodeError:
                continue
            old_funcs = get_functions_from_source(old_source)
            new_funcs = get_functions_from_source(new_source)
            diff_text = None  # decoded lazily, at most once per file diff
            for func_name, old_doc in old_funcs.items():
                new_doc = new_funcs.get(func_name)
                if new_doc is None:
                    continue  # function no longer exists under this name
                if old_doc != new_doc and len(old_doc) > 20 and len(new_doc) > 20:
                    if diff_text is None:
                        diff_text = diff.diff.decode('utf-8', errors='ignore')
                    yield format_for_model(diff_text, old_doc, new_doc)


def main():
    """Clone every configured repository, mine docstring changes, write JSONL.

    Repositories that fail to clone are reported and skipped. The temporary
    clone directory is removed before the dataset is written, whether or
    not mining succeeded.
    """
    import shutil  # local import: only needed for temp-dir cleanup

    dataset = []
    base_repo_dir = tempfile.mkdtemp()
    print(f"Using temporary directory for clones: {base_repo_dir}")
    try:
        for name, url in REPO_CONFIG.items():
            repo_dir = os.path.join(base_repo_dir, name)
            try:
                print(f"Cloning {name} from {url}...")
                repo = Repo.clone_from(url, repo_dir)
            except exc.GitCommandError as e:
                print(f"Error cloning {name}: {e}")
                continue
            print(f"Mining commit history for {name}...")
            try:
                dataset.extend(_mine_repo(repo))
            finally:
                # Release the git helper processes / file handles held by the repo.
                repo.close()
    finally:
        # The clones are only scratch data; do not leave them on disk.
        shutil.rmtree(base_repo_dir, ignore_errors=True)

    print(f"\nFound {len(dataset)} high-quality examples.")
    try:
        with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(item) + "\n" for item in dataset)
        print(f"Dataset successfully saved to '{OUTPUT_FILE}'.")
    except OSError as e:
        print(f"FATAL: Could not write final dataset file to {OUTPUT_FILE}. Error: {e}")
# Run the miner only when executed as a script, not when imported.
if __name__ == "__main__":
    main()