File size: 3,024 Bytes
198ccb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import argparse
from pathlib import Path
import pandas as pd
def _read_table(input_path: Path) -> pd.DataFrame:
    """Load *input_path* as TSV when its extension is ``.tsv``, otherwise as CSV."""
    sep = "\t" if input_path.suffix.lower() == ".tsv" else ","
    return pd.read_csv(input_path, sep=sep)


def prepare_tiny_dataset(
    input_path: Path,
    output_dir: Path,
    n_total: int = 110,
    n_train: int = 100,
) -> None:
    """
    Create tiny train/val splits from a full labeled dataset.

    The first ``n_total`` rows of the input (in file order, no shuffling) are
    written to ``tiny_all.csv``. The first ``n_train`` of those rows go to
    ``tiny_train.csv`` and the remainder go to ``tiny_val.csv``. If
    ``n_train`` is at least the number of kept rows, every kept row goes to
    train and the val file contains only the header.

    Expected input columns (can be adapted as needed):
      - id: unique identifier (created as 1..N if missing)
      - title: news title (Russian)
      - snippet: optional text snippet (added as an empty column if missing)
      - tags: comma-separated labels, e.g. "политика,экономика"

    Args:
        input_path: Dataset file, read as TSV when the extension is ``.tsv``
            and as CSV otherwise.
        output_dir: Directory for the three output CSVs; created if needed.
        n_total: Number of leading rows to keep.
        n_train: Number of the kept rows assigned to the training split.

    Raises:
        ValueError: If ``n_total`` or ``n_train`` is negative, or if the input
            lacks a ``title`` or ``tags`` column.
    """
    # pandas' head() treats a negative n as "all but the last |n| rows", which
    # would silently produce nonsensical splits, so reject negatives up front.
    if n_total < 0 or n_train < 0:
        raise ValueError(
            f"n_total and n_train must be non-negative "
            f"(got n_total={n_total}, n_train={n_train})."
        )

    df = _read_table(input_path)

    # Ensure required columns exist or can be constructed.
    if "title" not in df.columns:
        raise ValueError("Input dataset must contain a 'title' column.")
    if "tags" not in df.columns:
        raise ValueError("Input dataset must contain a 'tags' column (comma-separated labels).")
    if "snippet" not in df.columns:
        df["snippet"] = ""
    if "id" not in df.columns:
        df["id"] = range(1, len(df) + 1)

    # Create the output directory only after the input has been validated, so
    # a bad invocation does not leave an empty directory behind.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Take the first n_total rows (or fewer if the dataset is small).
    df_tiny = df.head(n_total).copy()
    # Positional slicing guarantees train and val exactly partition df_tiny.
    df_train = df_tiny.iloc[:n_train].copy()
    df_val = df_tiny.iloc[n_train:].copy()

    tiny_all_path = output_dir / "tiny_all.csv"
    tiny_train_path = output_dir / "tiny_train.csv"
    tiny_val_path = output_dir / "tiny_val.csv"
    df_tiny.to_csv(tiny_all_path, index=False)
    df_train.to_csv(tiny_train_path, index=False)
    df_val.to_csv(tiny_val_path, index=False)

    print(f"Saved tiny_all to: {tiny_all_path}")
    print(f"Saved tiny_train (n={len(df_train)}) to: {tiny_train_path}")
    print(f"Saved tiny_val (n={len(df_val)}) to: {tiny_val_path}")
def main(argv: "list[str] | None" = None) -> None:
    """
    Command-line entry point: parse arguments and build the tiny splits.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original script behavior.
    """
    parser = argparse.ArgumentParser(description="Prepare tiny train/val splits for quick finetuning.")
    parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Path to the full labeled dataset (CSV, or TSV if the file ends in .tsv).",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("data"),
        help="Directory to save tiny_all.csv, tiny_train.csv, tiny_val.csv (default: data).",
    )
    parser.add_argument(
        "--n-total",
        type=int,
        default=110,
        help="Total number of samples to keep for tiny dataset (default: 110).",
    )
    parser.add_argument(
        "--n-train",
        type=int,
        default=100,
        help="Number of samples to use for training (default: 100, rest go to val).",
    )
    args = parser.parse_args(argv)

    # Report invalid sizes as a normal CLI usage error (exit code 2) instead of
    # letting them reach the data code.
    if args.n_total < 0:
        parser.error("--n-total must be non-negative.")
    if args.n_train < 0:
        parser.error("--n-train must be non-negative.")

    prepare_tiny_dataset(
        input_path=args.input,
        output_dir=args.output_dir,
        n_total=args.n_total,
        n_train=args.n_train,
    )
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()
|