File size: 3,574 Bytes
8394c6e
 
 
a13b6ec
8394c6e
 
 
 
 
 
 
 
 
a13b6ec
8394c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a13b6ec
 
8394c6e
 
 
4ed2afe
 
 
 
 
 
 
 
 
 
 
 
 
8394c6e
 
3ef6520
 
 
 
 
 
 
 
 
 
 
8394c6e
3ef6520
 
 
 
 
a13b6ec
3ef6520
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ed2afe
eb46234
6d35419
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import ast
import json
import os
import shutil
import tempfile

from git import Repo, exc

# --- Configuration ---
# Repositories to mine; each key is also used as the clone's directory name.
REPO_CONFIG = {
    "fastapi": "https://github.com/tiangolo/fastapi.git",
    "requests": "https://github.com/psf/requests.git",
    "scikit-learn": "https://github.com/scikit-learn/scikit-learn.git",
}
# JSON Lines file that receives one training record per line.
OUTPUT_FILE = "diff_dataset.jsonl"
# Upper bound on the number of commits walked in each repository.
MAX_COMMITS_PER_REPO = 5_000

class FuncParser(ast.NodeVisitor):
    def __init__(self):
        self.functions = {}
    def visit_FunctionDef(self, node):
        docstring = ast.get_docstring(node) or ""
        self.functions[node.name] = docstring
        self.generic_visit(node)

def get_functions_from_source(source_code):
    """Parse *source_code* and return FuncParser's name -> docstring mapping.

    Returns ``{}`` when the interpreter cannot parse the source, for example
    Python 2 syntax in old commits or text containing NUL bytes, so a single
    bad file does not abort mining.
    """
    try:
        tree = ast.parse(source_code)
    except (SyntaxError, ValueError):
        # ValueError: before Python 3.12, ast.parse rejects NUL bytes with
        # ValueError rather than SyntaxError.
        return {}
    parser = FuncParser()
    parser.visit(tree)
    return parser.functions

def format_for_model(diff_text, old_doc, new_doc):
    """Render one (diff, old docstring, new docstring) triple as a training record.

    Args:
        diff_text: Unified-diff text of the changed file.
        old_doc: Docstring before the change; surrounding whitespace is stripped.
        new_doc: Docstring after the change; surrounding whitespace is stripped.

    Returns:
        A dict with a single ``"text"`` key holding a Markdown prompt with
        INSTRUCTION, GIT DIFF, OLD DOCUMENTATION and UPDATED DOCUMENTATION
        sections (each under a ``### `` header); the diff sits inside a
        properly closed ``diff`` code fence.
    """
    # Trailing newlines would leave a blank line before the closing fence.
    # Computed outside the f-string because backslashes are not allowed in
    # f-string expressions before Python 3.12.
    diff_body = diff_text.rstrip("\n")
    return {
        "text": f"""### INSTRUCTION:
A Python function's code was changed. Based on the `git diff` provided, update the function's documentation.

### GIT DIFF:
```diff
{diff_body}
```

### OLD DOCUMENTATION:

{old_doc.strip()}

### UPDATED DOCUMENTATION:

{new_doc.strip()}
"""
    }

def _examples_from_file_diff(diff):
    """Return training records for one file-level GitPython ``Diff``.

    One record is produced per function present on both sides of *diff* whose
    docstring changed and is longer than 20 characters before and after.
    Non-``.py`` paths, diffs missing either blob, and sources that are not
    valid UTF-8 yield an empty list.
    """
    if not (diff.a_path and diff.b_path
            and diff.a_path.endswith('.py') and diff.b_path.endswith('.py')):
        return []
    if diff.a_blob is None or diff.b_blob is None:
        return []

    try:
        old_source = diff.a_blob.data_stream.read().decode('utf-8')
        new_source = diff.b_blob.data_stream.read().decode('utf-8')
    except UnicodeDecodeError:
        return []

    old_funcs = get_functions_from_source(old_source)
    new_funcs = get_functions_from_source(new_source)

    examples = []
    diff_text = None  # decoded at most once per file, and only when needed
    for func_name, old_doc in old_funcs.items():
        if func_name not in new_funcs:
            continue
        new_doc = new_funcs[func_name]
        # Keep only real docstring edits where both versions are non-trivial.
        if old_doc == new_doc or len(old_doc) <= 20 or len(new_doc) <= 20:
            continue
        if diff_text is None:
            diff_text = diff.diff.decode('utf-8', errors='ignore')
        examples.append(format_for_model(diff_text, old_doc, new_doc))
    return examples


def _mine_repository(repo):
    """Yield training records from up to MAX_COMMITS_PER_REPO commits of *repo*.

    Each commit is compared against its first parent only; root commits
    (no parent) are skipped.
    """
    for commit in repo.iter_commits(max_count=MAX_COMMITS_PER_REPO):
        if not commit.parents:
            continue
        parent = commit.parents[0]
        # Diff parent -> commit so the ``a`` side is the pre-change source and
        # the patch reads forward. ``commit.diff(parent)`` would be reversed,
        # swapping old and new docstrings. ``unified=0`` drops context lines.
        for diff in parent.diff(commit, create_patch=True, unified=0):
            yield from _examples_from_file_diff(diff)


def main():
    """Mine docstring-change examples and write them to OUTPUT_FILE.

    Each repository in REPO_CONFIG is cloned into a fresh temporary
    directory. That directory is deleted once mining finishes, even on
    error. The collected records are written as JSON Lines.
    """
    dataset = []
    base_repo_dir = tempfile.mkdtemp()
    print(f"Using temporary directory for clones: {base_repo_dir}")

    try:
        for name, url in REPO_CONFIG.items():
            repo_dir = os.path.join(base_repo_dir, name)
            try:
                print(f"Cloning {name} from {url}...")
                repo = Repo.clone_from(url, repo_dir)
            except exc.GitCommandError as e:
                print(f"Error cloning {name}: {e}")
                continue

            print(f"Mining commit history for {name}...")
            try:
                dataset.extend(_mine_repository(repo))
            finally:
                # Release git helper processes/handles before the clone is removed.
                repo.close()
    finally:
        # Full clones can be large; do not leave them behind in the temp dir.
        shutil.rmtree(base_repo_dir, ignore_errors=True)

    print(f"\nFound {len(dataset)} high-quality examples.")

    try:
        with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(item) + "\n" for item in dataset)
        print(f"Dataset successfully saved to '{OUTPUT_FILE}'.")
    except OSError as e:
        print(f"FATAL: Could not write final dataset file to {OUTPUT_FILE}. Error: {e}")

# Run only when executed as a script. The module name is "__main__" in that
# case, never "main".
if __name__ == "__main__":
    main()