| import sys |
| from dataclasses import dataclass |
| from typing import TYPE_CHECKING, Optional, Union |
|
|
| import pandas as pd |
| import pyarrow as pa |
|
|
| import datasets |
| import datasets.config |
| from datasets.features.features import require_storage_cast |
| from datasets.table import table_cast |
|
|
|
|
| if TYPE_CHECKING: |
| import sqlite3 |
|
|
| import sqlalchemy |
|
|
|
|
# Module-level logger, obtained through the datasets logging utilities so that
# verbosity is controlled by the library-wide configuration.
logger = datasets.utils.logging.get_logger(__name__)
|
|
|
|
@dataclass
class SqlConfig(datasets.BuilderConfig):
    """BuilderConfig for SQL.

    The fields mirror the keyword arguments of ``pandas.read_sql``: a query or
    table (``sql``), a connection (``con``), plus optional column/index/date
    handling, a streaming ``chunksize``, and optional target ``features``.
    """

    sql: Union[str, "sqlalchemy.sql.Selectable"] = None
    con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"] = None
    index_col: Optional[Union[str, list[str]]] = None
    coerce_float: bool = True
    params: Optional[Union[list, tuple, dict]] = None
    parse_dates: Optional[Union[list, dict]] = None
    columns: Optional[list[str]] = None
    chunksize: Optional[int] = 10_000
    features: Optional[datasets.Features] = None

    def __post_init__(self):
        super().__post_init__()
        # Both the query/table and the connection are mandatory.
        for attr in ("sql", "con"):
            if getattr(self, attr) is None:
                raise ValueError(f"{attr} must be specified")

    def create_config_id(
        self,
        config_kwargs: dict,
        custom_features: Optional[datasets.Features] = None,
    ) -> str:
        """Build a deterministic config id, normalizing unhashable kwargs first.

        A sqlalchemy ``Selectable`` is rendered to SQL text so it can be hashed;
        a non-string connection falls back to ``id(con)`` (not stable across
        sessions, hence the info log).
        """
        # Work on a copy so the caller's kwargs are left untouched.
        config_kwargs = config_kwargs.copy()
        sql = config_kwargs["sql"]
        if not isinstance(sql, str):
            bad_sql_error = TypeError(
                f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}"
            )
            if not (datasets.config.SQLALCHEMY_AVAILABLE and "sqlalchemy" in sys.modules):
                raise bad_sql_error
            import sqlalchemy

            if not isinstance(sql, sqlalchemy.sql.Selectable):
                raise bad_sql_error
            # Compile the Selectable with the dialect inferred from the URI
            # scheme of 'con' so the rendered SQL text is deterministic.
            engine = sqlalchemy.create_engine(config_kwargs["con"].split("://")[0] + "://")
            config_kwargs["sql"] = str(sql.compile(dialect=engine.dialect))
        con = config_kwargs["con"]
        if not isinstance(con, str):
            config_kwargs["con"] = id(con)
            logger.info(
                f"SQL connection 'con' of type {type(con)} couldn't be hashed properly. To enable hashing, specify 'con' as URI string instead."
            )
        return super().create_config_id(config_kwargs, custom_features=custom_features)

    @property
    def pd_read_sql_kwargs(self):
        """Keyword arguments to forward to ``pandas.read_sql``."""
        return {
            "index_col": self.index_col,
            "columns": self.columns,
            "params": self.params,
            "coerce_float": self.coerce_float,
            "parse_dates": self.parse_dates,
        }
|
|
|
|
class Sql(datasets.ArrowBasedBuilder):
    """Arrow-based builder that streams a dataset out of a SQL query or table."""

    BUILDER_CONFIG_CLASS = SqlConfig

    def _info(self):
        # Only the (optional) user-provided features are known up-front.
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):
        # A SQL source produces a single, unsplit "train" split.
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={})]

    def _cast_table(self, pa_table: pa.Table) -> pa.Table:
        """Align a pyarrow table with the configured features, if any."""
        features = self.config.features
        if features is None:
            return pa_table
        schema = features.arrow_schema
        if any(require_storage_cast(feature) for feature in features.values()):
            # At least one feature needs a storage-level cast.
            return table_cast(pa_table, schema)
        # Cheaper path: select/reorder columns and stamp the target schema.
        return pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)

    def _generate_tables(self):
        """Yield ``(chunk_index, pyarrow.Table)`` pairs read via pandas."""
        chunksize = self.config.chunksize
        reader = pd.read_sql(
            self.config.sql, self.config.con, chunksize=chunksize, **self.config.pd_read_sql_kwargs
        )
        # With chunksize=None, pandas returns a single DataFrame instead of an
        # iterator of chunks; normalize to an iterable either way.
        chunks = [reader] if chunksize is None else reader
        for chunk_idx, df in enumerate(chunks):
            yield chunk_idx, self._cast_table(pa.Table.from_pandas(df))
|
|