# HF-Dataset-Commander / processor.py
# Author: broadfield-dev
# Commit 97353a3 (verified): Update processor.py
import json
import logging
import datasets
from datasets import load_dataset, get_dataset_config_names, get_dataset_infos
from huggingface_hub import HfApi
# Configure module-wide logging so job progress/errors are visible; the
# module-level logger is used by DatasetCommandCenter.process_and_push.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DatasetCommandCenter:
    """Inspect, transform, and republish Hugging Face datasets via streaming.

    Workflow: discover configs/splits (metadata only), preview a few
    transformed rows, then stream-transform the full split and push the
    result to the Hub. No source rows are materialized locally beyond the
    streaming buffer.
    """

    def __init__(self, token=None):
        # Hugging Face access token; forwarded to every Hub call.
        self.token = token

    def get_dataset_metadata(self, dataset_id):
        """
        Step 1: get available configs (subsets) and splits without
        downloading any data.

        Returns:
            {"status": "success", "configs": [...], "splits": [...]} or
            {"status": "error", "message": str}.
        """
        try:
            # 1. Configs (e.g. 'en', 'fr', or 'default').
            try:
                configs = get_dataset_config_names(dataset_id, token=self.token)
            except Exception:
                # Some datasets have no configs or raise here; assume a
                # single 'default' config.
                configs = ['default']
            # Guard: an empty list would make configs[0] raise below.
            if not configs:
                configs = ['default']
            # 2. Splits for the first config (to pre-populate the UI).
            # Splits for other configs are fetched on demand via
            # get_splits_for_config.
            selected_config = configs[0]
            try:
                # Fetches metadata (splits, columns) without downloading rows.
                infos = get_dataset_infos(dataset_id, token=self.token)
                # With multiple configs, infos is a dict keyed by config name.
                if selected_config in infos:
                    splits = list(infos[selected_config].splits.keys())
                else:
                    # Fallback if the structure is flat: take the first entry
                    # without materializing the whole values view.
                    splits = list(next(iter(infos.values())).splits.keys())
            except Exception:
                # Last resort: the conventional split names.
                splits = ['train', 'test', 'validation']
            return {
                "status": "success",
                "configs": configs,
                "splits": list(splits)
            }
        except Exception as e:
            return {"status": "error", "message": str(e)}

    def get_splits_for_config(self, dataset_id, config_name):
        """
        Helper to refresh the split list when the user changes the Config
        dropdown. Deliberately never reports an error: on failure it falls
        back to the conventional split names so the UI stays usable.
        """
        try:
            infos = get_dataset_infos(dataset_id, config_name=config_name, token=self.token)
            splits = list(infos[config_name].splits.keys())
            return {"status": "success", "splits": splits}
        except Exception:
            # Best-effort fallback (intentionally still "success").
            return {"status": "success", "splits": ['train', 'test', 'validation']}

    def inspect_dataset(self, dataset_id, config, split):
        """
        Step 2: stream up to 5 actual rows and analyze columns, flagging
        string columns that look like embedded JSON.

        Returns:
            {"status": "success", "samples": [...], "analysis": {...}} or
            {"status": "error", "message": str}.
        """
        try:
            # 'default' is a UI placeholder, not a real config name.
            conf = config if config != 'default' else None
            ds_stream = load_dataset(dataset_id, name=conf, split=split,
                                     streaming=True, token=self.token)
            sample_rows = []
            for i, row in enumerate(ds_stream):
                if i >= 5:
                    break
                # Stringify non-JSON-serializable values (e.g. PIL Images)
                # so the preview can be rendered/serialized safely.
                clean_row = {}
                for k, v in row.items():
                    if isinstance(v, (str, int, float, bool, list, dict, type(None))):
                        clean_row[k] = v
                    else:
                        clean_row[k] = str(v)
                sample_rows.append(clean_row)
            if not sample_rows:
                return {"status": "error", "message": "Dataset is empty."}
            # Analyze columns based on the first sampled row.
            analysis = {}
            for k, sample_val in sample_rows[0].items():
                is_json_str = False
                # A string bracketed like an object/array that parses cleanly
                # is treated as an embedded JSON column.
                if isinstance(sample_val, str):
                    s = sample_val.strip()
                    if (s.startswith('{') and s.endswith('}')) or \
                       (s.startswith('[') and s.endswith(']')):
                        try:
                            json.loads(s)
                            is_json_str = True
                        except ValueError:
                            pass
                analysis[k] = {
                    "type": type(sample_val).__name__,
                    "is_json_string": is_json_str
                }
            return {
                "status": "success",
                "samples": sample_rows,
                "analysis": analysis
            }
        except Exception as e:
            return {"status": "error", "message": str(e)}

    def _apply_transformations(self, row, recipe):
        """
        Apply the recipe to one row: JSON expansion, renames, then drops.
        Returns a new dict; the input row is not mutated.
        """
        new_row = row.copy()
        # 1. JSON expansions: extract selected keys from dict or
        # JSON-string columns into new flat columns.
        for item in recipe.get("json_expansions", []):
            col_name = item["col"]
            target_keys = item["keys"]
            source_data = new_row.get(col_name)
            parsed_obj = None
            # Case A: already a dict (Arrow struct column).
            if isinstance(source_data, dict):
                parsed_obj = source_data
            # Case B: a string holding serialized JSON.
            elif isinstance(source_data, str):
                try:
                    parsed_obj = json.loads(source_data)
                except ValueError:
                    pass  # Not valid JSON; extracted columns become None.
            # 'is not None' (not truthiness) so an empty-but-valid JSON
            # object/array still produces the requested columns (as None),
            # keeping the output schema consistent across rows.
            if parsed_obj is not None:
                for key in target_keys:
                    # Dot notation (e.g. "meta.url") walks nested objects;
                    # dots become underscores in the new column name.
                    clean_key = key.replace('.', '_')
                    new_col_name = f"{col_name}_{clean_key}"
                    val = parsed_obj
                    try:
                        for part in key.split('.'):
                            val = val[part]
                        new_row[new_col_name] = val
                    except (KeyError, IndexError, TypeError):
                        # Key path not found: keep the column, value None.
                        new_row[new_col_name] = None
        # 2. Renames (old -> new), applied after expansion so expanded
        # columns can also be renamed by the recipe.
        for old, new in recipe.get("renames", {}).items():
            if old in new_row:
                new_row[new] = new_row.pop(old)
        # 3. Drops.
        for drop_col in recipe.get("drops", []):
            new_row.pop(drop_col, None)
        return new_row

    def _passes_filter(self, row, filters):
        """
        Return True iff every filter expression evaluates truthy against
        the row's columns. Any expression that raises (bad syntax, missing
        column) rejects the row.

        SECURITY WARNING: filter strings are eval'd with row columns as
        locals. This executes arbitrary Python, so filters must only ever
        come from the trusted operator of this tool — never from untrusted
        input.
        """
        if not filters:
            return True
        context = row.copy()
        for expression in filters:
            try:
                if not eval(expression, {}, context):
                    return False
            except Exception:
                return False
        return True

    def process_and_push(self, source_id, config, split, target_id, recipe, max_rows=None):
        """
        Stream the source split, apply the recipe row by row, and push the
        result to `target_id` on the Hub.

        Raises whatever Dataset.from_generator / push_to_hub raise, after
        logging the error.
        """
        logger.info(f"Starting job: {source_id} ({config}/{split}) -> {target_id}")
        conf = config if config != 'default' else None
        # Parse once, outside the per-row loop (max_rows may arrive as a
        # string from the UI).
        limit = int(max_rows) if max_rows else None

        def gen():
            ds_stream = load_dataset(source_id, name=conf, split=split,
                                     streaming=True, token=self.token)
            count = 0
            for row in ds_stream:
                if limit is not None and count >= limit:
                    break
                # Transform first so filters can reference extracted JSON
                # fields as well as original columns.
                trans_row = self._apply_transformations(row, recipe)
                if self._passes_filter(trans_row, recipe.get("filters", [])):
                    yield trans_row
                    count += 1  # counts *emitted* rows; max_rows caps output

        try:
            # from_generator lets HF auto-detect the new schema.
            new_dataset = datasets.Dataset.from_generator(gen)
            new_dataset.push_to_hub(target_id, token=self.token)
            return {"status": "success", "rows_processed": len(new_dataset)}
        except Exception as e:
            logger.error(e)
            raise

    def preview_transform(self, dataset_id, config, split, recipe):
        """
        Return up to 5 transformed rows that pass the recipe's filters,
        for UI preview. Scans at most 1000 source rows so a filter that
        rejects everything cannot stream the entire dataset.
        """
        conf = config if config != 'default' else None
        ds_stream = load_dataset(dataset_id, name=conf, split=split,
                                 streaming=True, token=self.token)
        processed = []
        for i, row in enumerate(ds_stream):
            if len(processed) >= 5 or i >= 1000:
                break
            trans_row = self._apply_transformations(row, recipe)
            if self._passes_filter(trans_row, recipe.get("filters", [])):
                processed.append(trans_row)
        return processed