File size: 2,114 Bytes
ac9ddbb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
"""Utility functions for model training and evaluation."""
import os
from typing import List
# Default set of languages loaded by load_dataset_splits() when the caller
# does not pass an explicit ``langs`` argument.
LANGS: List[str] = [
    "java",
    "python",
    "pharo",
]
def load_dataset_splits(base_dir=None, langs=None, splits=("train", "test")):
    """Load per-language dataset splits from CSV files.

    Files are expected to be named ``{lang}_{split}.csv`` inside *base_dir*,
    e.g. ``data/raw/java_train.csv`` and ``data/raw/java_test.csv``.

    Args:
        base_dir: Directory containing the CSV files. Defaults to
            ``data/raw`` (a path relative to the current working directory).
        langs: Iterable of language names to load. Defaults to the
            module-level ``LANGS``.
        splits: Iterable of split names to load for every language.
            Defaults to ``("train", "test")``.

    Returns:
        A dict mapping ``"{lang}_{split}"`` (e.g. ``"java_test"``) to the
        pandas DataFrame read from the corresponding CSV file.

    Raises:
        FileNotFoundError: If *base_dir* is not a directory or an expected
            CSV file does not exist.
        ImportError: If pandas is not installed.
    """
    if base_dir is None:
        base_dir = os.path.join("data", "raw")
    if langs is None:
        langs = LANGS
    if not os.path.isdir(base_dir):
        raise FileNotFoundError(
            f"CSV datasets not found under {base_dir}; cannot load dataset splits."
        )

    # Imported here rather than at module level, so pandas is only required
    # when this function is actually called. Only a genuine ImportError is
    # translated; other errors from a broken install propagate unchanged.
    try:
        import pandas as pd
    except ImportError as e:
        raise ImportError("pandas is required to load dataset splits") from e

    datasets = {}
    for lang in langs:
        for split in splits:
            key = f"{lang}_{split}"
            path = os.path.join(base_dir, f"{key}.csv")
            # Check explicitly so the error names the exact missing file.
            if not os.path.isfile(path):
                raise FileNotFoundError(f"Expected dataset file missing: {path}")
            datasets[key] = pd.read_csv(path)
    return datasets
def parse_labels_column(df):
    """Parse the ``labels`` column of *df* into lists of ints, in place.

    Each cell of ``df["labels"]`` may be:

    * a string such as ``"[1 0 1]"``, ``"[1, 0, 1]"`` or ``"1 0 1"`` --
      surrounding square brackets are optional and tokens may be separated
      by whitespace and/or commas;
    * a numpy ``ndarray`` (only recognised when numpy is importable);
    * a ``list`` or ``tuple``.

    Args:
        df: DataFrame with a ``labels`` column. It is modified in place.

    Returns:
        The same *df* object, with ``labels`` replaced by lists of ``int``.

    Raises:
        ValueError: If a cell has any other type, or a token cannot be
            converted with ``int()``.
    """
    # Resolve the optional numpy dependency once instead of once per row.
    try:
        import numpy as np
    except ImportError:
        np = None

    def _parse_one(x):
        if isinstance(x, str):
            s = x.strip()
            if s.startswith("[") and s.endswith("]"):
                s = s[1:-1]
            # Treat commas as separators too, so both the numpy repr
            # "[1 0 1]" and the Python list repr "[1, 0, 1]" are accepted.
            return [int(tok) for tok in s.replace(",", " ").split()]
        if np is not None and isinstance(x, np.ndarray):
            return [int(v) for v in x.tolist()]
        if isinstance(x, (list, tuple)):
            return [int(v) for v in x]
        raise ValueError(f"Formato labels non gestito: {type(x)} -> {x!r}")

    df["labels"] = df["labels"].apply(_parse_one)
    return df
|