from __future__ import annotations

from typing import Iterator, List, Optional

import os

import pandas as pd
import pyarrow.parquet as pq
from tqdm import tqdm

from . import config as CFG
from .clean import clean_text, filter_short
from .utils import get_logger

logger = get_logger()


def _get_parquet_columns(path: str) -> List[str]:
    """Return the column names of a Parquet file without loading its data.

    Tries the cheap schema accessors first; only if both are unavailable
    does it materialize the first row group to inspect column names.

    Args:
        path: Filesystem path to a Parquet file.

    Returns:
        List of column names in file order.
    """
    pf = pq.ParquetFile(path)
    names: List[str]
    try:
        names = list(pf.schema.names)  # type: ignore[attr-defined]
    except Exception:
        try:
            names = list(pf.schema_arrow.names)  # type: ignore[attr-defined]
        except Exception:
            # Fallback: read a small handle (first row group) to get column names.
            tbl = pf.read_row_group(0)
            names = list(tbl.column_names)
    return names


def resolve_text_column(path: str, requested: Optional[str]) -> str:
    """Resolve the text column name to read from a Parquet file.

    Priority:
    1) Exact match of requested
    2) Case-insensitive match of requested
    3) First match from common candidates (Message, text, message, Text, body, content)
    4) Heuristic substring match ('msg'/'chat'/'comment')
    5) KeyError with a helpful message listing available columns

    Args:
        path: Filesystem path to a Parquet file.
        requested: Preferred column name, or None to rely on candidates.

    Returns:
        The resolved column name, guaranteed to exist in the file.

    Raises:
        ValueError: If the file reports no columns at all.
        KeyError: If no column can be resolved by any strategy.
    """
    cols = _get_parquet_columns(path)
    if not cols:
        raise ValueError(f"No columns found in parquet file: {path}")

    if requested:
        if requested in cols:
            return requested
        # Case-insensitive rescue: map lowered names back to their real spelling.
        lc_map = {c.lower(): c for c in cols}
        if requested.lower() in lc_map:
            resolved = lc_map[requested.lower()]
            logger.warning(f"Requested text column '{requested}' not found; using case-insensitive match '{resolved}'.")
            return resolved

    # Common candidates, tried in order of project preference.
    candidates = [
        "Message",
        "text",
        "message",
        "Text",
        "body",
        "content",
    ]
    for cand in candidates:
        if cand in cols:
            if requested and cand != requested:
                logger.warning(f"Requested text column '{requested}' not found; falling back to '{cand}'.")
            return cand

    # Last resort: pick the first column whose name looks text-like
    # (names often include 'msg'/'chat'/'comment').
    heuristics = ["msg", "chat", "comment"]
    for h in heuristics:
        for c in cols:
            if h in c.lower():
                logger.warning(f"Requested text column '{requested}' not found; heuristically using '{c}'.")
                return c

    raise KeyError(
        f"Text column '{requested}' not found. Available columns: {cols}. "
        f"Specify a valid name with --text_column or rename your data column."
    )


def read_parquet_stream(path: str, text_column: str | None = None, batch_size: int = 100_000) -> Iterator[List[str]]:
    """Yield cleaned text in batches from a Parquet file using pyarrow streaming.

    Args:
        path: Filesystem path to a Parquet file.
        text_column: Column to read; defaults to CFG.TEXT_COLUMN when None.
        batch_size: Number of rows per streamed batch.

    Yields:
        Lists of cleaned, non-empty text strings (one list per batch).
    """
    if text_column is None:
        text_column = CFG.TEXT_COLUMN
    # Resolve against actual parquet columns (handles rename to 'Message' or other synonyms)
    resolved_col = resolve_text_column(path, text_column)
    pf = pq.ParquetFile(path)
    for batch in pf.iter_batches(batch_size=batch_size, columns=[resolved_col]):
        tbl = batch.to_pandas(types_mapper=pd.ArrowDtype)
        # Drop null cells BEFORE stringifying: astype(str) would otherwise turn
        # missing values into literal "<NA>"/"None" tokens that survive cleaning
        # and pollute the yielded text.
        texts = tbl[resolved_col].dropna().astype(str).tolist()
        cleaned = [clean_text(t) for t in texts]
        cleaned = filter_short(cleaned, 1)
        yield cleaned


def train_val_split_stream(
    path: str,
    val_ratio: float = 0.05,
    text_column: str | None = None,
    batch_size: int = 100_000,
) -> Iterator[tuple[List[str], List[str]]]:
    """Stream batches and split into train and val lists per batch to avoid loading all in memory.

    NOTE: the split is applied within each batch (not globally), so the overall
    train/val proportion approximates val_ratio batch by batch.

    Args:
        path: Filesystem path to a Parquet file.
        val_ratio: Fraction of each batch reserved for validation.
        text_column: Column to read; defaults to CFG.TEXT_COLUMN when None.
        batch_size: Number of rows per streamed batch.

    Yields:
        (train_texts, val_texts) tuples, one per non-empty batch.
    """
    if text_column is None:
        text_column = CFG.TEXT_COLUMN
    for cleaned in read_parquet_stream(path, text_column, batch_size):
        n = len(cleaned)
        if n == 0:
            continue  # skip batches that cleaned down to nothing
        cut = int(n * (1 - val_ratio))
        yield cleaned[:cut], cleaned[cut:]