File size: 3,930 Bytes
c6e5251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from __future__ import annotations
from typing import Iterator, List, Optional
import os

import pandas as pd
import pyarrow.parquet as pq
from tqdm import tqdm

from . import config as CFG
from .clean import clean_text, filter_short
from .utils import get_logger

logger = get_logger()


def _get_parquet_columns(path: str) -> List[str]:
    """Return the column names of a Parquet file without reading its data.

    Tries cheap metadata accessors first; only materializes the first row
    group as a last resort (older pyarrow versions with neither accessor).

    Args:
        path: Filesystem path to the Parquet file.

    Returns:
        List of column names; empty if the file exposes no schema and has
        no row groups (the caller treats an empty list as an error).
    """
    pf = pq.ParquetFile(path)
    try:
        return list(pf.schema.names)  # ParquetSchema exposes .names
    except Exception:
        pass
    try:
        return list(pf.schema_arrow.names)  # Arrow-level schema on newer pyarrow
    except Exception:
        pass
    # Fallback: read one row group purely for its column names.
    # Guard against files with zero row groups, where read_row_group(0)
    # would raise instead of reporting "no columns".
    if pf.num_row_groups == 0:
        return []
    return list(pf.read_row_group(0).column_names)


def resolve_text_column(path: str, requested: Optional[str]) -> str:
    """Resolve the text column name to read from a Parquet file.

    Priority:
    1) Exact match of ``requested``
    2) Case-insensitive match of ``requested``
    3) First match from common candidates (Message, text, message, Text, body, content)
    4) Heuristic substring match ('msg', 'chat', 'comment')
    5) KeyError with a helpful message listing available columns

    Args:
        path: Filesystem path to the Parquet file.
        requested: Preferred column name, or None to rely on candidates only.

    Returns:
        The resolved column name, guaranteed to exist in the file.

    Raises:
        ValueError: If the file has no columns at all.
        KeyError: If no suitable text column can be found.
    """
    cols = _get_parquet_columns(path)
    if not cols:
        raise ValueError(f"No columns found in parquet file: {path}")

    if requested:
        if requested in cols:
            return requested
        lc_map = {c.lower(): c for c in cols}
        if requested.lower() in lc_map:
            resolved = lc_map[requested.lower()]
            logger.warning(f"Requested text column '{requested}' not found; using case-insensitive match '{resolved}'.")
            return resolved

    # Describe the situation accurately in warnings/errors: with requested=None
    # the old messages read "column 'None' not found", which was misleading.
    context = f"Requested text column '{requested}' not found" if requested else "No text column specified"

    # Common candidates
    candidates = [
        "Message", "text", "message", "Text", "body", "content",
    ]
    for cand in candidates:
        if cand in cols:
            if requested and cand != requested:
                logger.warning(f"{context}; falling back to '{cand}'.")
            return cand

    # Last resort: pick the first column whose name looks text-like
    # ('msg'/'chat'/'comment' substrings often mark the message column).
    heuristics = ["msg", "chat", "comment"]
    for h in heuristics:
        for c in cols:
            if h in c.lower():
                logger.warning(f"{context}; heuristically using '{c}'.")
                return c

    raise KeyError(
        f"{context}. Available columns: {cols}. "
        f"Specify a valid name with --text_column or rename your data column."
    )


def read_parquet_stream(path: str, text_column: str | None = None, batch_size: int = 100_000) -> Iterator[List[str]]:
    """Yield cleaned text in batches from a Parquet file using pyarrow streaming.

    Args:
        path: Filesystem path to the Parquet file.
        text_column: Column to read; defaults to ``CFG.TEXT_COLUMN`` and is
            resolved against the file's actual columns (handles renames such
            as 'Message' or other synonyms).
        batch_size: Maximum rows per yielded batch.

    Yields:
        Lists of cleaned, non-empty strings (one list per Arrow batch).
    """
    if text_column is None:
        text_column = CFG.TEXT_COLUMN
    resolved_col = resolve_text_column(path, text_column)
    pf = pq.ParquetFile(path)
    for batch in pf.iter_batches(batch_size=batch_size, columns=[resolved_col]):
        # Only one column was requested, so it is at index 0. to_pylist()
        # yields Python strings directly and keeps nulls as None — previously
        # the pandas round-trip with astype(str) turned nulls into literal
        # "<NA>"/"None" strings that leaked into the cleaned text.
        raw = batch.column(0).to_pylist()
        cleaned = [clean_text(t) for t in raw if t is not None]
        cleaned = filter_short(cleaned, 1)
        yield cleaned


def train_val_split_stream(path: str, val_ratio: float = 0.05, text_column: str | None = None, batch_size: int = 100_000):
    """Stream cleaned batches and yield a (train, val) split per batch.

    Splitting per batch avoids loading the whole dataset into memory:
    the last ``val_ratio`` fraction of each batch becomes validation data.

    Args:
        path: Filesystem path to the Parquet file.
        val_ratio: Fraction of each batch reserved for validation.
        text_column: Column to read; defaults to ``CFG.TEXT_COLUMN``.
        batch_size: Maximum rows per streamed batch.

    Yields:
        (train_texts, val_texts) tuples of lists of strings.
    """
    col = text_column if text_column is not None else CFG.TEXT_COLUMN
    for texts in read_parquet_stream(path, col, batch_size):
        if not texts:
            continue
        split_at = int(len(texts) * (1 - val_ratio))
        yield texts[:split_at], texts[split_at:]