Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Iterator, List, Optional | |
| import os | |
| import pandas as pd | |
| import pyarrow.parquet as pq | |
| from tqdm import tqdm | |
| from . import config as CFG | |
| from .clean import clean_text, filter_short | |
| from .utils import get_logger | |
| logger = get_logger() | |
def _get_parquet_columns(path: str) -> List[str]:
    """Return the column names of a Parquet file.

    Tries, in order:
      1. ``ParquetFile.schema.names`` (Parquet schema)
      2. ``ParquetFile.schema_arrow.names`` (Arrow schema)
      3. Reading the first row group and inspecting its column names.

    If both schema lookups fail and the file has no row groups, returns an
    empty list instead of letting ``read_row_group(0)`` raise — the caller
    (``resolve_text_column``) then reports a clear "no columns" error.
    """
    pf = pq.ParquetFile(path)
    try:
        return list(pf.schema.names)  # type: ignore[attr-defined]
    except Exception:
        pass
    try:
        return list(pf.schema_arrow.names)  # type: ignore[attr-defined]
    except Exception:
        pass
    # Fallback: materialize the first row group just to read its column names.
    if pf.metadata is not None and pf.metadata.num_row_groups == 0:
        return []  # empty file: nothing to read, let the caller report it
    tbl = pf.read_row_group(0)
    return list(tbl.column_names)
def resolve_text_column(path: str, requested: Optional[str]) -> str:
    """Resolve the text column name to read from a Parquet file.

    Priority:
      1) Exact match of ``requested``
      2) Case-insensitive match of ``requested``
      3) First match from common candidates (Message, text, message, Text, body, content)
      4) Heuristic substring match ('msg', 'chat', 'comment') against column names
      5) Error with a helpful message listing available columns

    Raises:
        ValueError: if the file exposes no columns at all.
        KeyError: if no column can be resolved by any of the strategies above.
    """
    cols = _get_parquet_columns(path)
    if not cols:
        raise ValueError(f"No columns found in parquet file: {path}")
    if requested:
        if requested in cols:
            return requested
        # Case-insensitive fallback: map lowercased names back to the originals.
        lc_map = {c.lower(): c for c in cols}
        resolved = lc_map.get(requested.lower())
        if resolved is not None:
            logger.warning(
                f"Requested text column '{requested}' not found; "
                f"using case-insensitive match '{resolved}'."
            )
            return resolved
    # Common candidate names, checked in priority order.
    candidates = [
        "Message", "text", "message", "Text", "body", "content",
    ]
    for cand in candidates:
        if cand in cols:
            if requested and cand != requested:
                logger.warning(f"Requested text column '{requested}' not found; falling back to '{cand}'.")
            return cand
    # Last resort: pick the first column whose name looks text-like
    # (names containing 'msg'/'chat'/'comment' often hold message text).
    heuristics = ["msg", "chat", "comment"]
    for h in heuristics:
        for c in cols:
            if h in c.lower():
                if requested:
                    logger.warning(f"Requested text column '{requested}' not found; heuristically using '{c}'.")
                else:
                    # Avoid the misleading "column 'None' not found" wording
                    # when no column was requested in the first place.
                    logger.warning(f"No text column requested; heuristically using '{c}'.")
                return c
    raise KeyError(
        f"Text column '{requested}' not found. Available columns: {cols}. "
        f"Specify a valid name with --text_column or rename your data column."
    )
def read_parquet_stream(path: str, text_column: str | None = None, batch_size: int = 100_000) -> Iterator[List[str]]:
    """Yield cleaned text in batches from a Parquet file using pyarrow streaming.

    Each yielded list holds the cleaned, length-filtered texts of one
    ``batch_size``-row batch; the full file is never loaded at once.
    """
    if text_column is None:
        text_column = CFG.TEXT_COLUMN
    # Map the requested name onto an actual column in the file
    # (handles renames such as 'Message' and other synonyms).
    column = resolve_text_column(path, text_column)
    parquet_file = pq.ParquetFile(path)
    for record_batch in parquet_file.iter_batches(batch_size=batch_size, columns=[column]):
        frame = record_batch.to_pandas(types_mapper=pd.ArrowDtype)
        raw_texts = frame[column].astype(str).tolist()
        cleaned_texts = [clean_text(text) for text in raw_texts]
        yield filter_short(cleaned_texts, 1)
def train_val_split_stream(path: str, val_ratio: float = 0.05, text_column: str | None = None, batch_size: int = 100_000):
    """Stream batches and yield a ``(train, val)`` pair per batch.

    Splitting inside each batch (rather than over the whole file) keeps
    memory bounded; empty batches are skipped entirely.
    """
    column = CFG.TEXT_COLUMN if text_column is None else text_column
    for batch in read_parquet_stream(path, column, batch_size):
        if not batch:
            continue
        split_at = int(len(batch) * (1 - val_ratio))
        yield batch[:split_at], batch[split_at:]