File size: 2,114 Bytes
ac9ddbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""Utility functions for model training and evaluation."""

import os
from typing import Any, Dict, List, Optional, Sequence, Union

# Default languages loaded by load_dataset_splits when `langs` is not given.
LANGS: List[str] = ["java", "python", "pharo"]


def load_dataset_splits(
    base_dir: Optional[Union[str, os.PathLike]] = None,
    langs: Optional[Union[str, Sequence[str]]] = None,
    splits: Sequence[str] = ("train", "test"),
) -> Dict[str, Any]:
    """Load per-language dataset splits from CSV files into pandas DataFrames.

    For every language ``lang`` in *langs* and every split name ``split`` in
    *splits*, reads ``<base_dir>/<lang>_<split>.csv`` (e.g.
    ``data/raw/java_train.csv``, ``data/raw/java_test.csv``).

    Args:
        base_dir: Directory containing the CSV files. Defaults to
            ``data/raw`` (resolved relative to the current working directory).
        langs: Languages to load. Defaults to the module-level ``LANGS``.
            A single language may be passed as a plain string.
        splits: Split names to load for each language. Defaults to
            ``("train", "test")``.

    Returns:
        Dict mapping ``"<lang>_<split>"`` (e.g. ``"java_test"``) to the
        DataFrame read from the corresponding CSV file.

    Raises:
        FileNotFoundError: if *base_dir* is not a directory or an expected
            CSV file does not exist.
        ImportError: if pandas is not installed.
    """
    if base_dir is None:
        base_dir = os.path.join("data", "raw")

    if langs is None:
        langs = LANGS
    elif isinstance(langs, str):
        # A bare string would otherwise be iterated character by character.
        langs = [langs]

    if not os.path.isdir(base_dir):
        raise FileNotFoundError(
            f"CSV datasets not found under {base_dir}; cannot load dataset splits."
        )

    try:
        import pandas as pd
    except ImportError as e:
        # Only a genuinely missing pandas is reported as such; other errors
        # raised while importing pandas propagate unchanged.
        raise ImportError("pandas is required to load dataset splits") from e

    paths = {
        f"{lang}_{split}": os.path.join(base_dir, f"{lang}_{split}.csv")
        for lang in langs
        for split in splits
    }

    # Validate every expected file up front so a missing one is reported
    # before any (potentially large) CSV has been parsed.
    for path in paths.values():
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Expected dataset file missing: {path}")

    return {key: pd.read_csv(path) for key, path in paths.items()}


def parse_labels_column(df):
    """Parse the ``labels`` column of *df* into lists of Python ints, in place.

    Each cell may be:

    * a string such as ``"[1 0 1]"`` (NumPy-style repr), ``"[1, 0, 1]"``
      (Python-list repr) or ``"1 0 1"``; the surrounding square brackets are
      optional and tokens may be separated by whitespace and/or commas;
    * a NumPy array (only recognised when NumPy is importable);
    * a list or tuple.

    Args:
        df: DataFrame with a ``labels`` column.

    Returns:
        The same *df* object; its ``labels`` column is replaced (the input is
        mutated, not copied).

    Raises:
        ValueError: if a cell has an unsupported type, or a token is not an
            integer literal.
    """
    # Resolve the optional NumPy dependency once per call instead of once per cell.
    try:
        import numpy as np
    except ImportError:
        np = None

    def _parse_one(x):
        if isinstance(x, str):
            s = x.strip()
            if s.startswith("[") and s.endswith("]"):
                s = s[1:-1]
            # Commas are treated as separators so Python-list reprs like "[1, 0]"
            # parse the same way as NumPy reprs like "[1 0]". str.split() with no
            # argument never yields empty tokens.
            return [int(tok) for tok in s.replace(",", " ").split()]
        if np is not None and isinstance(x, np.ndarray):
            return [int(v) for v in x.tolist()]
        if isinstance(x, (list, tuple)):
            return [int(v) for v in x]
        raise ValueError(f"Formato labels non gestito: {type(x)} -> {x!r}")

    df["labels"] = df["labels"].apply(_parse_one)
    return df