File size: 2,943 Bytes
b89e6d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Phase 1a: load raw CodeSearchNet (or the synthetic sample).

We normalise everything to a pandas DataFrame with two key columns:
  - docstring : natural-language description (model INPUT)
  - code      : function body (model TARGET)
plus useful metadata (language, repo, url) for traceability.
"""
from __future__ import annotations

import sys
from pathlib import Path

import pandas as pd

sys.path.append(str(Path(__file__).resolve().parents[2]))
from src.config import load_config  # noqa: E402
from src.data.make_sample import make_sample  # noqa: E402

# Map CodeSearchNet's verbose column names to our canonical names.
_COLUMN_MAP = {
    "func_documentation_string": "docstring",
    "func_code_string": "code",
    "language": "language",
    "repository_name": "repo",
    "func_code_url": "url",
}


def _from_huggingface(cfg) -> pd.DataFrame:
    """Download CodeSearchNet for each configured language and merge the result.

    For every entry of ``cfg.data.languages`` the dataset identified by
    ``cfg.data.hf_dataset_id`` is loaded, each split it exposes is converted
    to pandas, reduced to the columns listed in ``_COLUMN_MAP`` and renamed to
    the canonical names. All pieces are concatenated under a fresh index.

    When ``cfg.data.max_rows`` is a positive number smaller than the combined
    row count, a random subset of that size (fixed ``random_state=42``) is
    returned instead; a missing, zero or negative value disables the cap.
    """
    from datasets import load_dataset

    max_rows = getattr(cfg.data, "max_rows", 0)

    def _canonical(split_frame):
        # Keep only the raw columns that are present, in _COLUMN_MAP order.
        present = [raw for raw in _COLUMN_MAP if raw in split_frame.columns]
        return split_frame[present].rename(columns=_COLUMN_MAP)

    pieces = []
    for lang in cfg.data.languages:
        print(f"[load] downloading CodeSearchNet '{lang}' ...")
        # Every split that comes back is pulled in; re-splitting happens later.
        per_language = load_dataset(cfg.data.hf_dataset_id, lang)
        pieces.extend(
            _canonical(per_language[split_name].to_pandas())
            for split_name in per_language.keys()
        )

    out = pd.concat(pieces, ignore_index=True)
    print(f"[load] total raw rows: {len(out):,}")

    # Guard clause: nothing to trim unless a positive cap is actually exceeded.
    over_cap = max_rows and max_rows > 0 and len(out) > max_rows
    if not over_cap:
        return out

    out = out.sample(n=max_rows, random_state=42).reset_index(drop=True)
    print(f"[load] capped to {max_rows:,} rows (max_rows setting)")
    return out


def _from_sample(cfg) -> pd.DataFrame:
    """Build the synthetic sample and reduce it to the canonical columns.

    ``make_sample`` is called with ``cfg.data.sample_size`` and
    ``cfg.split.seed``. Of its output, only columns whose names are keys of
    ``_COLUMN_MAP`` are kept, and those are renamed to the canonical names.
    """
    print(f"[load] using synthetic sample (n={cfg.data.sample_size})")
    sample = make_sample(cfg.data.sample_size, cfg.split.seed)

    # NOTE(review): the filter matches raw CodeSearchNet names only, so a
    # column already called e.g. "docstring" or "code" would be dropped here.
    # Presumably make_sample emits the raw names; confirm against make_sample.
    present = []
    for raw_name in _COLUMN_MAP:
        if raw_name in sample.columns:
            present.append(raw_name)

    selected = sample[present]
    return selected.rename(columns=_COLUMN_MAP)


def load_raw(cfg=None) -> pd.DataFrame:
    """Return the raw dataset as a DataFrame carrying the canonical columns.

    Args:
        cfg: Configuration object; when omitted (or falsy) ``load_config()``
            supplies one. ``cfg.data.use_sample`` selects the synthetic
            sample, otherwise the HuggingFace loader is used.

    Returns:
        A DataFrame that is guaranteed to have the columns ``docstring``,
        ``code``, ``language``, ``repo`` and ``url``; any of them missing
        from the loaded data is added and filled with empty strings.

    Raises:
        Exception: whatever the HuggingFace loader raised, re-raised unchanged
            after a troubleshooting hint has been written to stderr.
    """
    if not cfg:
        cfg = load_config()

    if cfg.data.use_sample:
        df = _from_sample(cfg)
    else:
        try:
            df = _from_huggingface(cfg)
        except Exception as e:  # noqa: BLE001
            # Broad catch on purpose: only a hint is added, the error still propagates.
            hint = (
                f"[load] HuggingFace load failed ({e}).\n"
                f"        Tip: try hf_dataset_id: 'code-search-net/code_search_net' "
                f"in config.yaml, or set use_sample: true."
            )
            print(hint, file=sys.stderr)
            raise

    # Downstream code relies on these columns existing even without metadata.
    required = ("docstring", "code", "language", "repo", "url")
    absent = [name for name in required if name not in df.columns]
    for name in absent:
        df[name] = ""
    return df


if __name__ == "__main__":
    # Smoke run: load with the default config, then show the shape and the
    # first three rows.
    raw = load_raw()
    print(raw.shape)
    preview = raw.head(3)
    print(preview.to_string())