File size: 15,042 Bytes
e020674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
from abc import ABC, abstractmethod
from dataflow import get_logger
import pandas as pd
import json
from typing import Any, Literal
import os


class DataFlowStorage(ABC):
    """
    Abstract interface implemented by every DataFlow storage backend.

    Concrete backends must override both ``read`` and ``write``; the class
    itself cannot be instantiated.
    """

    @abstractmethod
    def read(self, output_type) -> Any:
        """
        Load the stored data and return it in the requested representation.

        output_type: the representation to read into, such as "dataframe" or
            a list of dicts; the accepted values are defined by each backend.
        """

    @abstractmethod
    def write(self, data: Any) -> Any:
        """
        Persist ``data`` to the backend; the return value is backend-specific.
        """

class FileStorage(DataFlowStorage):
    """
    Storage backend that keeps the data of each pipeline step in a file.

    Step 0 reads the user-supplied entry source (a local file or a remote
    dataset, see ``read``). Every later step N reads
    ``{cache_path}/{file_name_prefix}_{N}.{cache_type}``, and ``write`` always
    targets the file of step N + 1. ``step()`` must be called once before the
    first ``read``/``write`` and once more for every following operator.
    """

    _STEP_NOT_STARTED_MSG = (
        "You must call storage.step() before reading or writing data. "
        "Please call storage.step() first for each operator step."
    )

    def __init__(self,
                 first_entry_file_name: str,
                 cache_path: str = "./cache",
                 file_name_prefix: str = "dataflow_cache_step",
                 cache_type: Literal["json", "jsonl", "csv", "parquet", "pickle"] = "jsonl"
                 ):
        """
        first_entry_file_name: source read at step 0 (local path, or an
            ``hf:`` / ``ms:`` dataset reference).
        cache_path: directory that receives the per-step cache files.
        file_name_prefix: file name prefix of the per-step cache files.
        cache_type: on-disk format (and extension) of the cache files.
        """
        self.first_entry_file_name = first_entry_file_name
        self.cache_path = cache_path
        self.file_name_prefix = file_name_prefix
        self.cache_type = cache_type
        # -1 means "step() has not been called yet".
        self.operator_step = -1
        self.logger = get_logger()

    def _get_cache_file_path(self, step) -> str:
        """
        Return the path backing ``step``.

        Raises:
            ValueError: if ``step`` is -1, i.e. ``step()`` was never called.
        """
        if step == -1:
            self.logger.error(self._STEP_NOT_STARTED_MSG)
            raise ValueError(self._STEP_NOT_STARTED_MSG)
        if step == 0:
            # If it's the first step, use the first entry file name
            return os.path.join(self.first_entry_file_name)
        return os.path.join(self.cache_path, f"{self.file_name_prefix}_{step}.{self.cache_type}")

    def step(self):
        """Advance to the next operator step and return ``self`` for chaining."""
        self.operator_step += 1
        return self

    def reset(self):
        """Return to the "not started" state and return ``self`` for chaining."""
        self.operator_step = -1
        return self

    def _load_local_file(self, file_path: str, file_type: str) -> pd.DataFrame:
        """
        Load ``file_path`` into a DataFrame according to ``file_type``.

        Raises:
            FileNotFoundError: if ``file_path`` does not exist.
            ValueError: if ``file_type`` is unsupported or pandas fails to
                parse the file (the pandas error is chained as the cause).
        """
        # check if file exists
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File {file_path} does not exist. Please check the path.")
        # Reject unknown types up front so the error is not re-wrapped below.
        if file_type not in ("json", "jsonl", "csv", "parquet", "pickle"):
            raise ValueError(f"Unsupported file type: {file_type}")
        # Load file based on type
        try:
            if file_type == "json":
                return pd.read_json(file_path)
            if file_type == "jsonl":
                return pd.read_json(file_path, lines=True)
            if file_type == "csv":
                return pd.read_csv(file_path)
            if file_type == "parquet":
                return pd.read_parquet(file_path)
            # NOTE(security): unpickling can execute arbitrary code; only load
            # pickle files that come from a trusted source.
            return pd.read_pickle(file_path)
        except Exception as e:
            raise ValueError(f"Failed to load {file_type} file: {str(e)}") from e

    def _convert_output(self, dataframe: pd.DataFrame, output_type: str) -> Any:
        """
        Convert ``dataframe`` to the requested output type.

        Raises:
            ValueError: if ``output_type`` is neither "dataframe" nor "dict".
        """
        if output_type == "dataframe":
            return dataframe
        if output_type == "dict":
            return dataframe.to_dict(orient="records")
        raise ValueError(f"Unsupported output type: {output_type}")

    def read(self, output_type: Literal["dataframe", "dict"]) -> Any:
        """
        Read data from current file managed by storage.

        At step 0 the data comes from ``first_entry_file_name``, which may be
        a local file (its type is taken from the file extension) or a remote
        dataset reference:
            - "hf:{dataset_name}[:config]:split" => HuggingFace dataset,
              e.g. "hf:openai/gsm8k:main:train"
            - "ms:{dataset_name}[:split]"        => ModelScope dataset,
              e.g. "ms:modelscope/gsm8k:train"
        When the split is omitted, "train" is used. At later steps the cache
        file of the current step is read using ``cache_type``.

        Args:
            output_type: Type that you want to read to, either "dataframe" or "dict".

        Returns:
            Depending on output_type:
            - "dataframe": pandas DataFrame
            - "dict": List of dictionaries

        Raises:
            ValueError: For unsupported file types or output types, or when
                ``step()`` has not been called yet
            FileNotFoundError: If the local file does not exist
        """
        file_path = self._get_cache_file_path(self.operator_step)
        self.logger.info(f"Reading data from {file_path} with type {output_type}")

        if self.operator_step != 0:
            dataframe = self._load_local_file(file_path, self.cache_type)
            return self._convert_output(dataframe, output_type)

        source = self.first_entry_file_name
        if source.startswith("hf:"):
            self.logger.info(f"Reading remote dataset from {source} with type {output_type}")
            from datasets import load_dataset
            _, dataset_name, *parts = source.split(":")

            if len(parts) == 1:
                config, split = None, parts[0]
            elif len(parts) == 2:
                config, split = parts
            else:
                config, split = None, "train"

            dataset = (
                load_dataset(dataset_name, config, split=split)
                if config
                else load_dataset(dataset_name, split=split)
            )
            return self._convert_output(dataset.to_pandas(), output_type)

        if source.startswith("ms:"):
            self.logger.info(f"Reading remote dataset from {source} with type {output_type}")
            from modelscope import MsDataset
            _, dataset_name, *split_parts = source.split(":")
            split = split_parts[0] if split_parts else "train"

            dataset = MsDataset.load(dataset_name, split=split)
            return self._convert_output(pd.DataFrame(dataset), output_type)

        # Local entry file: use the real extension (dots in directory names
        # are ignored) and accept upper-case extensions such as ".JSONL".
        local_type = os.path.splitext(file_path)[1].lstrip(".").lower()
        dataframe = self._load_local_file(file_path, local_type)
        return self._convert_output(dataframe, output_type)

    def write(self, data: Any) -> Any:
        """
        Write data to the cache file of the next step and return its path.

        data: a pandas DataFrame or a list of dicts (an empty list is written
            as an empty DataFrame).

        Raises:
            ValueError: if ``data`` has an unsupported type, ``cache_type`` is
                unsupported, or ``step()`` has not been called yet.
        """
        if isinstance(data, pd.DataFrame):
            dataframe = data
        elif isinstance(data, list):
            # Only the first element is inspected, as a cheap sanity check.
            if data and not isinstance(data[0], dict):
                raise ValueError(f"Unsupported data type: {type(data[0])}")
            dataframe = pd.DataFrame(data)
        else:
            raise ValueError(f"Unsupported data type: {type(data)}")

        # Without this guard, step -1 + 1 == 0 would resolve to the entry file
        # and overwrite the pipeline input.
        if self.operator_step == -1:
            self.logger.error(self._STEP_NOT_STARTED_MSG)
            raise ValueError(self._STEP_NOT_STARTED_MSG)

        file_path = self._get_cache_file_path(self.operator_step + 1)
        parent_dir = os.path.dirname(file_path)
        # dirname is "" when cache_path is empty; makedirs("") would fail.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        self.logger.success(f"Writing data to {file_path} with type {self.cache_type}")
        if self.cache_type == "json":
            dataframe.to_json(file_path, orient="records", force_ascii=False, indent=2)
        elif self.cache_type == "jsonl":
            dataframe.to_json(file_path, orient="records", lines=True, force_ascii=False)
        elif self.cache_type == "csv":
            dataframe.to_csv(file_path, index=False)
        elif self.cache_type == "parquet":
            dataframe.to_parquet(file_path)
        elif self.cache_type == "pickle":
            dataframe.to_pickle(file_path)
        else:
            raise ValueError(f"Unsupported file type: {self.cache_type}, output file should end with json, jsonl, csv, parquet, pickle")

        return file_path

from threading import Lock

# Module-level cache of ClickHouse clients, filled by get_clickhouse_client()
# and keyed by a tuple of connection parameters.
_clickhouse_clients = {}
# Held by get_clickhouse_client() while it looks up or creates a cached client.
_clickhouse_clients_lock = Lock()

# Predefined field-name prefixes
SYS_FIELD_PREFIX = 'sys:'
USER_FIELD_PREFIX = 'user:'

# Get the shared ClickHouse client for a connection configuration
def get_clickhouse_client(db_config):
    """
    Return the cached clickhouse_driver ``Client`` for ``db_config``.

    One client is created per distinct (host, port, user, database) tuple and
    reused by later calls with the same tuple.

    db_config: dict with a required 'host' key and optional 'port'
        (default 9000), 'user' (default 'default'), 'password' (default '')
        and 'database' (default 'default') keys.

    Raises:
        KeyError: if 'host' is missing from ``db_config``.
        ImportError: if clickhouse_driver is not installed.
    """
    host = db_config['host']
    port = db_config.get('port', 9000)
    user = db_config.get('user', 'default')
    # The same default is used for the cache key and for the connection, so a
    # config that omits 'database' never shares a client with a config that
    # names a different database explicitly.
    database = db_config.get('database', 'default')
    key = (host, port, user, database)
    with _clickhouse_clients_lock:
        client = _clickhouse_clients.get(key)
        if client is None:
            try:
                from clickhouse_driver import Client
            except ImportError as e:
                raise ImportError("clickhouse_driver is required for MyScaleDBStorage but not installed. Please install it via 'pip install clickhouse-driver'.") from e
            client = Client(
                host=host,
                port=port,
                user=user,
                password=db_config.get('password', ''),
                database=database,
                settings={"use_numpy": True}
            )
            _clickhouse_clients[key] = client
        return client

# Safely decode a JSON value
def safe_json_loads(x):
    """
    Best-effort conversion of a stored value into a Python object.

    - dict: returned unchanged.
    - str: parsed as JSON; if parsing fails the original string is returned.
    - scalar missing value (None, NaN, NaT, ...): returns None.
    - anything else (lists, arrays, numbers, ...): returned unchanged.
    """
    if isinstance(x, dict):
        return x
    if isinstance(x, str):
        try:
            return json.loads(x)
        except (ValueError, RecursionError):
            # Not valid JSON (JSONDecodeError is a ValueError) or nested too
            # deeply to parse: keep the original string.
            return x
    # pd.isna() returns an array for list-like input, which cannot be used as
    # a condition, so only scalars are tested for missingness.
    if pd.api.types.is_scalar(x) and pd.isna(x):
        return None
    return x  # every other type is returned as-is

# 预定义min_hashes计算方法,当前全部返回[0]
def _default_min_hashes(self, data_dict):
    return [0]

class MyScaleDBStorage(DataFlowStorage):
    """
    Storage for Myscale/ClickHouse database using clickhouse_driver.

    ``read`` selects the rows of ``pipeline_id`` whose ``task_id`` equals
    ``input_task_id``; ``write`` inserts rows tagged with ``pipeline_id`` and
    ``output_task_id``. The record itself travels in the ``data`` column as a
    JSON string.
    """
    def validate_required_params(self):
        """
        Validate the key parameters of this storage instance.

        - pipeline_id, input_task_id and output_task_id must be non-empty,
          otherwise ValueError is raised.
        - page_size and page_num receive their defaults (page_size=10000,
          page_num=0) when they are unset or None.

        Operators should call this method before using the storage.
        """
        required = ('pipeline_id', 'input_task_id', 'output_task_id')
        missing = [name for name in required if not getattr(self, name)]
        if missing:
            raise ValueError(f"Missing required storage parameters: {', '.join(missing)}")
        if getattr(self, 'page_size', None) is None:
            self.page_size = 10000
        if getattr(self, 'page_num', None) is None:
            self.page_num = 0

    def __init__(
        self,
        db_config: dict,
        pipeline_id: str = None,
        input_task_id: str = None,
        output_task_id: str = None,
        page_size: int = 10000,
        page_num: int = 0
    ):
        """
        db_config: {
            'host': 'localhost',
            'port': 9000,
            'user': 'default',
            'password': '',
            'database': 'dataflow',
            'table': 'dataflow_table'
        }
        pipeline_id: str, identifier of the current pipeline.
        input_task_id: str, identifier of the task whose rows are read.
        output_task_id: str, identifier attached to the rows that are written.
        page_size: int, number of records per page (default 10000).
        page_num: int, current page number (default 0).

        The three identifiers default to None in the signature, but
        validate_required_params() runs at the end of construction and raises
        ValueError when any of them is empty.
        """
        self.db_config = db_config
        self.client = get_clickhouse_client(db_config)
        self.table = db_config.get('table', 'dataflow_table')
        self.logger = get_logger()
        self.pipeline_id: str = pipeline_id
        self.input_task_id: str = input_task_id
        self.output_task_id: str = output_task_id
        self.page_size: int = page_size
        self.page_num: int = page_num
        self.validate_required_params()

    def read(self, output_type: Literal["dataframe", "dict"]) -> Any:
        """
        Read one page of records from the Myscale/ClickHouse table.

        Rows are filtered by ``pipeline_id`` and ``input_task_id`` (when set)
        and limited to ``page_size`` rows. ``page_num`` 0 and 1 both select
        the first page; ``page_num`` k >= 1 skips (k - 1) * page_size rows.
        A falsy ``page_size`` disables paging.

        Returns:
            - "dataframe": DataFrame with a single ``data`` column
            - "dict": list with the decoded ``data`` value of every row

        Raises:
            ValueError: if the result has no ``data`` column or
                ``output_type`` is unsupported.
        """
        where_clauses = []
        params = {}
        if self.pipeline_id:
            where_clauses.append("pipeline_id = %(pipeline_id)s")
            params['pipeline_id'] = self.pipeline_id
        if self.input_task_id:
            where_clauses.append("task_id = %(task_id)s")
            params['task_id'] = self.input_task_id
        where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
        if self.page_size:
            # Clamp at 0: the default page_num=0 used to produce a negative
            # OFFSET, which is not valid SQL.
            offset = max(self.page_num - 1, 0) * self.page_size
            limit_offset = f"LIMIT {self.page_size} OFFSET {offset}"
        else:
            limit_offset = ""
        sql = f"SELECT * FROM {self.table} {where_sql} {limit_offset}"
        self.logger.info(f"Reading from DB: {sql} with params {params}")
        rows, col_types = self.client.execute(sql, params, with_column_types=True)
        columns = [col[0] for col in col_types]
        df = pd.DataFrame(rows, columns=columns)
        if 'data' not in df.columns:
            raise ValueError("Result does not contain required 'data' field.")

        # Only the data column is returned, decoded from JSON where possible.
        data_series = df['data'].apply(safe_json_loads)

        if output_type == "dataframe":
            return pd.DataFrame({'data': data_series})
        elif output_type == "dict":
            return list(data_series)
        else:
            raise ValueError(f"Unsupported output type: {output_type}")

    def write(self, data: Any) -> Any:
        """
        Write records to the Myscale/ClickHouse table.

        data: pd.DataFrame or List[dict]. If it has a ``data`` column, each
            value of that column is one record (dict or JSON string; any
            other value is stored as an empty record). Otherwise every row
            is itself one record. The input object is not modified.

        Each inserted row carries pipeline_id, task_id (= output_task_id),
        raw_data_id (taken from the record's 'sys:raw_data_id' entry,
        default 0), min_hashes and the record as a JSON string.

        Returns:
            A summary string with the number of inserted rows.

        Raises:
            ValueError: if ``data`` is neither a list nor a DataFrame (an
                invalid JSON string in the ``data`` column also surfaces as
                a ValueError from ``json.loads``).
        """
        if isinstance(data, list):
            df = pd.DataFrame(data)
        elif isinstance(data, pd.DataFrame):
            df = data
        else:
            raise ValueError(f"Unsupported data type: {type(data)}")

        # Extract the records without adding columns to the caller's frame.
        if 'data' in df.columns:
            payloads = df['data'].tolist()
        else:
            # Each row is one record; to_dict keeps per-column value types
            # (a row-wise apply would upcast mixed columns) and also works
            # for an empty frame.
            payloads = df.to_dict(orient="records")

        values = []
        for raw in payloads:
            if isinstance(raw, dict):
                payload = raw
            elif isinstance(raw, str):
                payload = json.loads(raw)
            else:
                payload = {}

            if isinstance(payload, dict):
                raw_data_id = payload.get(SYS_FIELD_PREFIX + 'raw_data_id', 0)
                # The helper is declared as (self, data_dict); pass both.
                min_hashes = _default_min_hashes(self, payload)
            else:
                raw_data_id = 0
                min_hashes = [0]

            # The data column is stored as a JSON string.
            encoded = payload if isinstance(payload, str) else json.dumps(payload, ensure_ascii=False)
            values.append((
                self.pipeline_id,
                self.output_task_id,
                int(raw_data_id),
                min_hashes,
                encoded,
            ))

        if not values:
            # Nothing to insert: skip the round trip to the database.
            self.logger.info(f"Inserting {len(values)} rows into {self.table}")
            return f"Inserted {len(values)} rows into {self.table}"

        insert_sql = f"""
        INSERT INTO {self.table} (pipeline_id, task_id, raw_data_id, min_hashes, data)
        VALUES
        """
        self.logger.info(f"Inserting {len(values)} rows into {self.table}")
        self.client.execute(insert_sql, values)
        return f"Inserted {len(values)} rows into {self.table}"