from abc import ABC, abstractmethod
from dataflow import get_logger
import pandas as pd
import json
from typing import Any, Literal
import os
class DataFlowStorage(ABC):
"""
Abstract base class for data storage.
"""
@abstractmethod
def read(self, output_type) -> Any:
"""
Read data from file.
        output_type: type that you want to read to, such as "dataframe", List[dict], etc.
"""
pass
@abstractmethod
def write(self, data: Any) -> Any:
pass
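# Illustrative sketch (not part of DataFlow): a minimal in-memory implementation of the
# DataFlowStorage interface, kept as a comment so the module's public surface is unchanged.
# The class name and behavior are assumptions for demonstration only.
#
#     class InMemoryStorage(DataFlowStorage):
#         """Keeps the last written payload in memory; handy for unit tests."""
#         def __init__(self):
#             self._data = None
#
#         def read(self, output_type) -> Any:
#             # Returned as-is; a real implementation would convert to the requested type.
#             return self._data
#
#         def write(self, data: Any) -> Any:
#             self._data = data
#             return data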
class FileStorage(DataFlowStorage):
"""
Storage for file system.
"""
def __init__(self,
first_entry_file_name: str,
cache_path:str="./cache",
file_name_prefix:str="dataflow_cache_step",
cache_type:Literal["json", "jsonl", "csv", "parquet", "pickle"] = "jsonl"
):
self.first_entry_file_name = first_entry_file_name
self.cache_path = cache_path
self.file_name_prefix = file_name_prefix
self.cache_type = cache_type
self.operator_step = -1
self.logger = get_logger()
def _get_cache_file_path(self, step) -> str:
if step == -1:
self.logger.error("You must call storage.step() before reading or writing data. Please call storage.step() first for each operator step.")
raise ValueError("You must call storage.step() before reading or writing data. Please call storage.step() first for each operator step.")
if step == 0:
# If it's the first step, use the first entry file name
            return self.first_entry_file_name
else:
return os.path.join(self.cache_path, f"{self.file_name_prefix}_{step}.{self.cache_type}")
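    # Illustrative example (assuming the default constructor arguments above): step 0
    # resolves to first_entry_file_name itself, while step N > 0 resolves to
    # "./cache/dataflow_cache_step_N.jsonl".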
def step(self):
self.operator_step += 1
return self
def reset(self):
self.operator_step = -1
return self
def _load_local_file(self, file_path: str, file_type: str) -> pd.DataFrame:
"""Load data from local file based on file type."""
# check if file exists
if not os.path.exists(file_path):
raise FileNotFoundError(f"File {file_path} does not exist. Please check the path.")
# Load file based on type
try:
if file_type == "json":
return pd.read_json(file_path)
elif file_type == "jsonl":
return pd.read_json(file_path, lines=True)
elif file_type == "csv":
return pd.read_csv(file_path)
elif file_type == "parquet":
return pd.read_parquet(file_path)
elif file_type == "pickle":
return pd.read_pickle(file_path)
else:
raise ValueError(f"Unsupported file type: {file_type}")
except Exception as e:
raise ValueError(f"Failed to load {file_type} file: {str(e)}")
def _convert_output(self, dataframe: pd.DataFrame, output_type: str) -> Any:
"""Convert dataframe to requested output type."""
if output_type == "dataframe":
return dataframe
elif output_type == "dict":
return dataframe.to_dict(orient="records")
raise ValueError(f"Unsupported output type: {output_type}")
def read(self, output_type: Literal["dataframe", "dict"]) -> Any:
"""
Read data from current file managed by storage.
Args:
            output_type: Type that you want to read to, either "dataframe" or "dict".
        The first entry file name also supports remote datasets via a prefix:
            - "hf:{dataset_name}[:config][:split]" => HuggingFace dataset, e.g. "hf:openai/gsm8k:main:train"
            - "ms:{dataset_name}[:split]" => ModelScope dataset, e.g. "ms:modelscope/gsm8k:train"
Returns:
Depending on output_type:
- "dataframe": pandas DataFrame
- "dict": List of dictionaries
Raises:
ValueError: For unsupported file types or output types
"""
file_path = self._get_cache_file_path(self.operator_step)
self.logger.info(f"Reading data from {file_path} with type {output_type}")
if self.operator_step == 0:
source = self.first_entry_file_name
self.logger.info(f"Reading remote dataset from {source} with type {output_type}")
if source.startswith("hf:"):
from datasets import load_dataset
_, dataset_name, *parts = source.split(":")
if len(parts) == 1:
config, split = None, parts[0]
elif len(parts) == 2:
config, split = parts
else:
config, split = None, "train"
dataset = (
load_dataset(dataset_name, config, split=split)
if config
else load_dataset(dataset_name, split=split)
)
dataframe = dataset.to_pandas()
return self._convert_output(dataframe, output_type)
elif source.startswith("ms:"):
from modelscope import MsDataset
_, dataset_name, *split_parts = source.split(":")
split = split_parts[0] if split_parts else "train"
dataset = MsDataset.load(dataset_name, split=split)
dataframe = pd.DataFrame(dataset)
return self._convert_output(dataframe, output_type)
else:
local_cache = file_path.split(".")[-1]
else:
local_cache = self.cache_type
dataframe = self._load_local_file(file_path, local_cache)
return self._convert_output(dataframe, output_type)
def write(self, data: Any) -> Any:
"""
Write data to current file managed by storage.
        data: Any, the data to write; it should be a pandas DataFrame, List[dict], etc.
"""
        if isinstance(data, list):
            if not data:
                raise ValueError("Cannot write an empty list")
            if isinstance(data[0], dict):
                dataframe = pd.DataFrame(data)
            else:
                raise ValueError(f"Unsupported data type: {type(data[0])}")
        elif isinstance(data, pd.DataFrame):
            dataframe = data
        else:
            raise ValueError(f"Unsupported data type: {type(data)}")
file_path = self._get_cache_file_path(self.operator_step + 1)
        os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
self.logger.success(f"Writing data to {file_path} with type {self.cache_type}")
if self.cache_type == "json":
dataframe.to_json(file_path, orient="records", force_ascii=False, indent=2)
elif self.cache_type == "jsonl":
dataframe.to_json(file_path, orient="records", lines=True, force_ascii=False)
elif self.cache_type == "csv":
dataframe.to_csv(file_path, index=False)
elif self.cache_type == "parquet":
dataframe.to_parquet(file_path)
elif self.cache_type == "pickle":
dataframe.to_pickle(file_path)
else:
raise ValueError(f"Unsupported file type: {self.cache_type}, output file should end with json, jsonl, csv, parquet, pickle")
return file_path
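# Illustrative usage sketch (kept as a comment; file names and fields are assumptions,
# not part of DataFlow): each operator calls storage.step() once, reads the previous
# step's output, and writes its own output to the next cache file.
#
#     storage = FileStorage(first_entry_file_name="input.jsonl", cache_type="jsonl")
#     storage.step()                         # step 0: read input.jsonl
#     df = storage.read(output_type="dataframe")
#     df["processed"] = True                 # hypothetical operator logic
#     storage.write(df)                      # writes ./cache/dataflow_cache_step_1.jsonl
#
# The first entry may also be a remote dataset, e.g.
#     FileStorage(first_entry_file_name="hf:openai/gsm8k:main:train")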
from threading import Lock
_clickhouse_clients = {}
_clickhouse_clients_lock = Lock()
# Predefined field prefixes
SYS_FIELD_PREFIX = 'sys:'
USER_FIELD_PREFIX = 'user:'
# Get a singleton ClickHouse client per connection config
def get_clickhouse_client(db_config):
key = (
db_config['host'],
db_config.get('port', 9000),
db_config.get('user', 'default'),
db_config.get('database', 'dataflow'),
)
with _clickhouse_clients_lock:
if key not in _clickhouse_clients:
try:
from clickhouse_driver import Client
except ImportError as e:
raise ImportError("clickhouse_driver is required for MyScaleDBStorage but not installed. Please install it via 'pip install clickhouse-driver'.") from e
_clickhouse_clients[key] = Client(
host=db_config['host'],
port=db_config.get('port', 9000),
user=db_config.get('user', 'default'),
password=db_config.get('password', ''),
database=db_config.get('database', 'default'),
settings={"use_numpy": True}
)
return _clickhouse_clients[key]
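# Illustrative db_config sketch (host and credentials are placeholders): clients are cached
# per (host, port, user, database), so repeated calls with the same config reuse one client.
#
#     db_config = {
#         "host": "localhost",
#         "port": 9000,
#         "user": "default",
#         "password": "",
#         "database": "dataflow",
#         "table": "dataflow_table",
#     }
#     client = get_clickhouse_client(db_config)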
# Safely load JSON data
def safe_json_loads(x):
if isinstance(x, dict):
return x
if isinstance(x, str):
try:
return json.loads(x)
except Exception:
            return x  # keep the original string
if pd.isna(x):
return None
    return x  # return other types unchanged
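# Illustrative behavior of safe_json_loads:
#     safe_json_loads('{"a": 1}')   -> {"a": 1}
#     safe_json_loads("not json")   -> "not json"   (original string kept)
#     safe_json_loads(float("nan")) -> None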
# Predefined min_hashes computation; currently returns [0] for every record
def _default_min_hashes(data_dict):
    return [0]
class MyScaleDBStorage(DataFlowStorage):
"""
Storage for Myscale/ClickHouse database using clickhouse_driver.
"""
def validate_required_params(self):
"""
校验MyScaleDBStorage实例的关键参数有效性:
- pipeline_id, input_task_id, output_task_id 必须非空,否则抛出异常。
- page_size, page_num 若未设置则赋默认值(page_size=10000, page_num=0)。
所有算子在使用storage前应调用本方法。
"""
missing = []
if not self.pipeline_id:
missing.append('pipeline_id')
if not self.input_task_id:
missing.append('input_task_id')
if not self.output_task_id:
missing.append('output_task_id')
if missing:
raise ValueError(f"Missing required storage parameters: {', '.join(missing)}")
if not hasattr(self, 'page_size') or self.page_size is None:
self.page_size = 10000
if not hasattr(self, 'page_num') or self.page_num is None:
self.page_num = 0
def __init__(
self,
db_config: dict,
pipeline_id: str = None,
input_task_id: str = None,
output_task_id: str = None,
page_size: int = 10000,
page_num: int = 0
):
"""
db_config: {
'host': 'localhost',
'port': 9000,
'user': 'default',
'password': '',
'database': 'dataflow',
'table': 'dataflow_table'
}
        pipeline_id: str, identifier of the current pipeline (optional, default None)
        input_task_id: str, identifier of the input task (optional, default None)
        output_task_id: str, identifier of the output task (optional, default None)
        page_size: int, number of records per page when paginating (default 10000)
        page_num: int, current page number (default 0)
"""
self.db_config = db_config
self.client = get_clickhouse_client(db_config)
self.table = db_config.get('table', 'dataflow_table')
self.logger = get_logger()
self.pipeline_id: str = pipeline_id
self.input_task_id: str = input_task_id
self.output_task_id: str = output_task_id
self.page_size: int = page_size
self.page_num: int = page_num
self.validate_required_params()
def read(self, output_type: Literal["dataframe", "dict"]) -> Any:
"""
Read data from Myscale/ClickHouse table.
"""
where_clauses = []
params = {}
if self.pipeline_id:
where_clauses.append("pipeline_id = %(pipeline_id)s")
params['pipeline_id'] = self.pipeline_id
if self.input_task_id:
where_clauses.append("task_id = %(task_id)s")
params['task_id'] = self.input_task_id
where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
limit_offset = f"LIMIT {self.page_size} OFFSET {(self.page_num-1)*self.page_size}" if self.page_size else ""
sql = f"SELECT * FROM {self.table} {where_sql} {limit_offset}"
self.logger.info(f"Reading from DB: {sql} with params {params}")
result = self.client.execute(sql, params, with_column_types=True)
rows, col_types = result
columns = [col[0] for col in col_types]
df = pd.DataFrame(rows, columns=columns)
        # Parse the data field into dicts
if 'data' not in df.columns:
raise ValueError("Result does not contain required 'data' field.")
        # Keep only the data field
data_series = df['data'].apply(safe_json_loads)
if output_type == "dataframe":
            # Return a DataFrame with only the data column
return pd.DataFrame({'data': data_series})
elif output_type == "dict":
            # Return the data field as a list of dicts
return list(data_series)
else:
raise ValueError(f"Unsupported output type: {output_type}")
def write(self, data: Any) -> Any:
"""
Write data to Myscale/ClickHouse table.
        data: pd.DataFrame or List[dict]; each row is the content of the data field (a dict).
"""
if isinstance(data, list):
df = pd.DataFrame(data)
elif isinstance(data, pd.DataFrame):
df = data
else:
raise ValueError(f"Unsupported data type: {type(data)}")
        # The data field itself holds the content of each row
if 'data' not in df.columns:
            # Compatibility for passing a list of plain dicts directly
df['data'] = df.apply(lambda row: row.to_dict(), axis=1)
        # Normalize the data column
df['data'] = df['data'].apply(lambda x: x if isinstance(x, dict) else (json.loads(x) if isinstance(x, str) else {}))
        # Automatically fill pipeline_id, task_id, raw_data_id, min_hashes
df['pipeline_id'] = self.pipeline_id
df['task_id'] = self.output_task_id
df['raw_data_id'] = df['data'].apply(lambda d: d.get(SYS_FIELD_PREFIX + 'raw_data_id', 0) if isinstance(d, dict) else 0)
df['min_hashes'] = df['data'].apply(lambda d: _default_min_hashes(d) if isinstance(d, dict) else [0])
        # Convert the data field to a JSON string
df['data'] = df['data'].apply(lambda x: json.dumps(x, ensure_ascii=False) if not isinstance(x, str) else x)
        # Keep only the required columns
required_cols = ['pipeline_id', 'task_id', 'raw_data_id', 'min_hashes', 'data']
df = df[required_cols]
records = df.to_dict(orient="records")
values = [
(
rec['pipeline_id'],
rec['task_id'],
int(rec['raw_data_id']),
rec['min_hashes'],
rec['data']
) for rec in records
]
insert_sql = f"""
INSERT INTO {self.table} (pipeline_id, task_id, raw_data_id, min_hashes, data)
VALUES
"""
self.logger.info(f"Inserting {len(values)} rows into {self.table}")
self.client.execute(insert_sql, values)
return f"Inserted {len(values)} rows into {self.table}" |