Twitch-BPE / src /io_data.py
Soldier-Boy's picture
create: src files
c6e5251 verified
from __future__ import annotations
from typing import Iterator, List, Optional
import os
import pandas as pd
import pyarrow.parquet as pq
from tqdm import tqdm
from . import config as CFG
from .clean import clean_text, filter_short
from .utils import get_logger
# Module-level logger; the text-column resolution logic below uses it to warn
# when it falls back from the requested column name.
logger = get_logger()
def _get_parquet_columns(path: str) -> List[str]:
    """Return the top-level column names of the Parquet file at ``path``.

    Only the footer metadata is read via ``pq.read_schema``; no row group is
    decoded, so this also works for files that contain zero row groups.

    The Arrow schema is used (rather than the low-level Parquet schema) because
    its names are the top-level field names that ``ParquetFile.iter_batches``
    accepts in ``columns=``. The Parquet schema enumerates leaf columns, whose
    names are not valid selections for nested types.
    """
    schema = pq.read_schema(path)
    return list(schema.names)
def resolve_text_column(path: str, requested: Optional[str]) -> str:
"""Resolve the text column name to read from a Parquet file.
Priority:
1) Exact match of requested
2) Case-insensitive match of requested
3) First match from common candidates (Message, text, message, Text, body, content)
4) Error with helpful message listing available columns
"""
cols = _get_parquet_columns(path)
if not cols:
raise ValueError(f"No columns found in parquet file: {path}")
if requested:
if requested in cols:
return requested
lc_map = {c.lower(): c for c in cols}
if requested.lower() in lc_map:
resolved = lc_map[requested.lower()]
logger.warning(f"Requested text column '{requested}' not found; using case-insensitive match '{resolved}'.")
return resolved
# Common candidates
candidates = [
"Message", "text", "message", "Text", "body", "content",
]
for cand in candidates:
if cand in cols:
if requested and cand != requested:
logger.warning(f"Requested text column '{requested}' not found; falling back to '{cand}'.")
return cand
# Last resort: pick the first string-like column by heuristic (object or large string names often include 'msg'/'chat')
heuristics = ["msg", "chat", "comment"]
for h in heuristics:
for c in cols:
if h in c.lower():
logger.warning(f"Requested text column '{requested}' not found; heuristically using '{c}'.")
return c
raise KeyError(
f"Text column '{requested}' not found. Available columns: {cols}. "
f"Specify a valid name with --text_column or rename your data column."
)
def read_parquet_stream(path: str, text_column: str | None = None, batch_size: int = 100_000) -> Iterator[List[str]]:
    """Yield cleaned text from a Parquet file, one list per record batch.

    Only the resolved text column is read, ``batch_size`` rows at a time, via
    ``ParquetFile.iter_batches``. The file is not loaded into memory at once.

    Args:
        path: Parquet file to read.
        text_column: Requested text column; ``None`` means ``CFG.TEXT_COLUMN``.
            The name goes through ``resolve_text_column``, so case-insensitive
            matches and common synonyms (e.g. 'Message') are accepted.
        batch_size: Maximum number of rows per Arrow record batch.

    Yields:
        One list per batch: each non-null value as ``str``, passed through
        ``clean_text`` and then ``filter_short(..., 1)``. A list can be empty
        (e.g. a batch of only nulls); callers should tolerate that.

    Raises:
        ValueError, KeyError: propagated from ``resolve_text_column``.
    """
    if text_column is None:
        text_column = CFG.TEXT_COLUMN
    # Resolve against the actual parquet columns (handles renames/synonyms).
    resolved_col = resolve_text_column(path, text_column)
    pf = pq.ParquetFile(path)
    for batch in pf.iter_batches(batch_size=batch_size, columns=[resolved_col]):
        series = batch.to_pandas(types_mapper=pd.ArrowDtype)[resolved_col]
        # Drop nulls before astype(str): otherwise a missing message becomes a
        # literal placeholder string that survives cleaning and ends up in
        # the training text.
        texts = series.dropna().astype(str).tolist()
        cleaned = filter_short([clean_text(t) for t in texts], 1)
        yield cleaned
def train_val_split_stream(path: str, val_ratio: float = 0.05, text_column: str | None = None, batch_size: int = 100_000):
"""Stream batches and split into train and val lists per batch to avoid loading all in memory."""
if text_column is None:
text_column = CFG.TEXT_COLUMN
for cleaned in read_parquet_stream(path, text_column, batch_size):
n = len(cleaned)
if n == 0:
continue
cut = int(n * (1 - val_ratio))
yield cleaned[:cut], cleaned[cut:]