broadfield-dev committed on
Commit
4425935
·
verified ·
1 Parent(s): 1a2363c

Update processor.py

Browse files
Files changed (1) hide show
  1. processor.py +117 -140
processor.py CHANGED
@@ -2,7 +2,6 @@ import json
2
  import logging
3
  import datasets
4
  from datasets import load_dataset, get_dataset_config_names, get_dataset_infos
5
- from huggingface_hub import HfApi
6
 
7
  # Configure logging
8
  logging.basicConfig(level=logging.INFO)
@@ -13,185 +12,166 @@ class DatasetCommandCenter:
13
  self.token = token
14
 
15
  def get_dataset_metadata(self, dataset_id):
16
- """
17
- Step 1: Get available Configs (subsets) and Splits without downloading data.
18
- """
19
  try:
20
- # 1. Get Configs (e.g., 'en', 'fr' or 'default')
21
  try:
22
  configs = get_dataset_config_names(dataset_id, token=self.token)
23
- except Exception:
24
- # Some datasets have no configs or throw errors, default to 'default' or None
25
  configs = ['default']
26
-
27
- # 2. Get Splits for the first config (to pre-populate)
28
- # We will fetch specific splits for other configs dynamically if needed
29
- selected_config = configs[0]
30
 
 
31
  try:
32
- # This fetches metadata (splits, columns) without downloading rows
33
  infos = get_dataset_infos(dataset_id, token=self.token)
34
- # If multiple configs, infos is a dict keyed by config name
35
- if selected_config in infos:
36
- splits = list(infos[selected_config].splits.keys())
37
  else:
38
- # Fallback if structure is flat
39
  splits = list(infos.values())[0].splits.keys()
40
  except:
41
- # Fallback: try to just list simple splits
42
  splits = ['train', 'test', 'validation']
43
 
44
- return {
45
- "status": "success",
46
- "configs": configs,
47
- "splits": list(splits)
48
- }
49
  except Exception as e:
50
  return {"status": "error", "message": str(e)}
51
 
52
  def get_splits_for_config(self, dataset_id, config_name):
53
- """
54
- Helper to update splits when user changes the Config dropdown
55
- """
56
  try:
57
  infos = get_dataset_infos(dataset_id, config_name=config_name, token=self.token)
58
  splits = list(infos[config_name].splits.keys())
59
  return {"status": "success", "splits": splits}
60
- except Exception as e:
61
- # Fallback
62
  return {"status": "success", "splits": ['train', 'test', 'validation']}
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def inspect_dataset(self, dataset_id, config, split):
65
  """
66
- Step 2: Stream actual rows and detect JSON.
67
  """
68
  try:
69
- # Handle 'default' config edge cases
70
  conf = config if config != 'default' else None
71
-
72
  ds_stream = load_dataset(dataset_id, name=conf, split=split, streaming=True, token=self.token)
73
 
74
  sample_rows = []
 
 
 
75
  for i, row in enumerate(ds_stream):
76
- if i >= 5: break
77
- # Convert non-serializable objects (like PIL Images) to strings for preview
 
78
  clean_row = {}
79
  for k, v in row.items():
 
80
  if not isinstance(v, (str, int, float, bool, list, dict, type(None))):
81
- clean_row[k] = str(v)
82
  else:
83
  clean_row[k] = v
84
  sample_rows.append(clean_row)
85
 
86
- if not sample_rows:
87
- return {"status": "error", "message": "Dataset is empty."}
 
88
 
89
- # Analyze Columns
90
- analysis = {}
91
- keys = sample_rows[0].keys()
92
 
93
- for k in keys:
94
- sample_val = sample_rows[0][k]
95
- col_type = type(sample_val).__name__
96
- is_json_str = False
97
-
98
- # Check if string looks like JSON
99
- if isinstance(sample_val, str):
100
- s = sample_val.strip()
101
- if (s.startswith('{') and s.endswith('}')) or (s.startswith('[') and s.endswith(']')):
102
- try:
103
- json.loads(s)
104
- is_json_str = True
105
- except:
106
- pass
107
-
108
- analysis[k] = {
109
- "type": col_type,
110
- "is_json_string": is_json_str
111
- }
112
 
113
  return {
114
  "status": "success",
115
- "samples": sample_rows,
116
- "analysis": analysis
 
117
  }
118
  except Exception as e:
119
  return {"status": "error", "message": str(e)}
120
 
121
- def _apply_transformations(self, row, recipe):
122
  """
123
- Apply Parsing, Renaming, Dropping, Filtering
124
  """
125
- new_row = row.copy()
 
126
 
127
- # 1. JSON Expansions
128
- if "json_expansions" in recipe:
129
- for item in recipe["json_expansions"]:
130
- col_name = item["col"]
131
- target_keys = item["keys"]
132
-
133
- # Check if we need to parse string-json first
134
- source_data = new_row.get(col_name)
135
-
136
- parsed_obj = None
137
-
138
- # Case A: It is already a dict (Struct)
139
- if isinstance(source_data, dict):
140
- parsed_obj = source_data
141
- # Case B: It is a string (JSON String)
142
- elif isinstance(source_data, str):
143
  try:
144
- parsed_obj = json.loads(source_data)
145
  except:
146
- pass
147
-
148
- if parsed_obj:
149
- for key in target_keys:
150
- # Handle Nested Dot Notation (e.g. "meta.url")
151
- val = parsed_obj
152
- parts = key.split('.')
153
- try:
154
- for p in parts:
155
- val = val[p]
156
-
157
- # Create new column name (replace dots with underscores)
158
- clean_key = key.replace('.', '_')
159
- new_col_name = f"{col_name}_{clean_key}"
160
- new_row[new_col_name] = val
161
- except:
162
- # Key not found
163
- clean_key = key.replace('.', '_')
164
- new_row[f"{col_name}_{clean_key}"] = None
165
-
166
- # 2. Renames
167
- if "renames" in recipe:
168
- for old, new in recipe["renames"].items():
169
- if old in new_row:
170
- new_row[new] = new_row.pop(old)
171
-
172
- # 3. Drops
173
- if "drops" in recipe:
174
- for drop_col in recipe["drops"]:
175
- if drop_col in new_row:
176
- del new_row[drop_col]
177
 
 
 
 
 
 
 
 
 
 
178
  return new_row
179
 
180
- def _passes_filter(self, row, filters):
181
- if not filters: return True
182
- context = row.copy()
183
- for f_str in filters:
184
- try:
185
- # Safety: very basic eval.
186
- if not eval(f_str, {}, context):
187
- return False
188
- except:
189
- return False
190
- return True
 
 
 
 
 
191
 
192
  def process_and_push(self, source_id, config, split, target_id, recipe, max_rows=None):
193
- logger.info(f"Starting job: {source_id} ({config}/{split}) -> {target_id}")
194
-
195
  conf = config if config != 'default' else None
196
 
197
  def gen():
@@ -201,35 +181,32 @@ class DatasetCommandCenter:
201
  if max_rows and count >= int(max_rows):
202
  break
203
 
204
- # Transform first (so filters apply to NEW schema if needed,
205
- # OR change order depending on preference. Here we filter RAW data usually,
206
- # but for JSON extraction we often filter on extracted fields.
207
- # Let's Apply Transform -> Then Filter to allow filtering on extracted JSON fields)
208
-
209
- trans_row = self._apply_transformations(row, recipe)
210
-
211
- if self._passes_filter(trans_row, recipe.get("filters", [])):
212
- yield trans_row
213
  count += 1
214
 
215
- # Push to Hub
216
- # Note: We must infer features or let HF do it.
217
- # Using a generator allows HF to auto-detect the new schema.
218
  try:
 
219
  new_dataset = datasets.Dataset.from_generator(gen)
220
  new_dataset.push_to_hub(target_id, token=self.token)
221
  return {"status": "success", "rows_processed": len(new_dataset)}
222
  except Exception as e:
223
  logger.error(e)
224
  raise e
225
-
226
  def preview_transform(self, dataset_id, config, split, recipe):
227
  conf = config if config != 'default' else None
228
  ds_stream = load_dataset(dataset_id, name=conf, split=split, streaming=True, token=self.token)
229
  processed = []
230
- for i, row in enumerate(ds_stream):
 
231
  if len(processed) >= 5: break
232
- trans_row = self._apply_transformations(row, recipe)
233
- if self._passes_filter(trans_row, recipe.get("filters", [])):
234
- processed.append(trans_row)
 
 
235
  return processed
 
2
  import logging
3
  import datasets
4
  from datasets import load_dataset, get_dataset_config_names, get_dataset_infos
 
5
 
6
  # Configure logging
7
  logging.basicConfig(level=logging.INFO)
 
12
  self.token = token
13
 
14
  def get_dataset_metadata(self, dataset_id):
 
 
 
15
  try:
 
16
  try:
17
  configs = get_dataset_config_names(dataset_id, token=self.token)
18
+ except:
 
19
  configs = ['default']
 
 
 
 
20
 
21
+ # Try to fetch splits for the first config
22
  try:
 
23
  infos = get_dataset_infos(dataset_id, token=self.token)
24
+ first_conf = configs[0]
25
+ if first_conf in infos:
26
+ splits = list(infos[first_conf].splits.keys())
27
  else:
 
28
  splits = list(infos.values())[0].splits.keys()
29
  except:
 
30
  splits = ['train', 'test', 'validation']
31
 
32
+ return {"status": "success", "configs": configs, "splits": list(splits)}
 
 
 
 
33
  except Exception as e:
34
  return {"status": "error", "message": str(e)}
35
 
36
  def get_splits_for_config(self, dataset_id, config_name):
 
 
 
37
  try:
38
  infos = get_dataset_infos(dataset_id, config_name=config_name, token=self.token)
39
  splits = list(infos[config_name].splits.keys())
40
  return {"status": "success", "splits": splits}
41
+ except:
 
42
  return {"status": "success", "splits": ['train', 'test', 'validation']}
43
 
44
+ def _flatten_object(self, obj, parent_key='', sep='.'):
45
+ """
46
+ Recursively finds all keys in a nested dictionary (or JSON string).
47
+ Returns a dict of { 'path': sample_value }.
48
+ """
49
+ items = {}
50
+
51
+ # If it's a string, try to parse it as JSON first
52
+ if isinstance(obj, str):
53
+ obj = obj.strip()
54
+ if (obj.startswith('{') and obj.endswith('}')):
55
+ try:
56
+ obj = json.loads(obj)
57
+ except:
58
+ pass # It's just a string
59
+
60
+ if isinstance(obj, dict):
61
+ for k, v in obj.items():
62
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
63
+ items.update(self._flatten_object(v, new_key, sep=sep))
64
+ else:
65
+ # It's a leaf node (int, str, list, etc.)
66
+ items[parent_key] = obj
67
+
68
+ return items
69
+
70
  def inspect_dataset(self, dataset_id, config, split):
71
  """
72
+ Scans first N rows to build a map of ALL available fields (including nested JSON).
73
  """
74
  try:
 
75
  conf = config if config != 'default' else None
 
76
  ds_stream = load_dataset(dataset_id, name=conf, split=split, streaming=True, token=self.token)
77
 
78
  sample_rows = []
79
+ available_paths = set()
80
+
81
+ # We scan 20 rows to find schema variations (some JSON keys might be optional)
82
  for i, row in enumerate(ds_stream):
83
+ if i >= 20: break
84
+
85
+ # Store a clean version for UI preview
86
  clean_row = {}
87
  for k, v in row.items():
88
+ # Handle bytes/images
89
  if not isinstance(v, (str, int, float, bool, list, dict, type(None))):
90
+ clean_row[k] = f"<{type(v).__name__}>"
91
  else:
92
  clean_row[k] = v
93
  sample_rows.append(clean_row)
94
 
95
+ # Schema Inference: Flatten this row to find all possible dot-notation paths
96
+ flattened = self._flatten_object(row)
97
+ available_paths.update(flattened.keys())
98
 
99
+ # Sort paths naturally
100
+ sorted_paths = sorted(list(available_paths))
 
101
 
102
+ # Group paths by top-level column for the UI
103
+ schema_tree = {}
104
+ for path in sorted_paths:
105
+ root = path.split('.')[0]
106
+ if root not in schema_tree:
107
+ schema_tree[root] = []
108
+ schema_tree[root].append(path)
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  return {
111
  "status": "success",
112
+ "samples": sample_rows[:5], # Send 5 to frontend
113
+ "schema_tree": schema_tree, # { 'meta': ['meta', 'meta.url', 'meta.id'] }
114
+ "dataset_id": dataset_id
115
  }
116
  except Exception as e:
117
  return {"status": "error", "message": str(e)}
118
 
119
+ def _get_value_by_path(self, row, path):
120
  """
121
+ Extracts value using dot notation, parsing JSON strings on the fly if needed.
122
  """
123
+ keys = path.split('.')
124
+ current_data = row
125
 
126
+ try:
127
+ for i, key in enumerate(keys):
128
+ # 1. If current_data is a JSON string, parse it
129
+ if isinstance(current_data, str):
 
 
 
 
 
 
 
 
 
 
 
 
130
  try:
131
+ current_data = json.loads(current_data)
132
  except:
133
+ return None # Parsing failed
134
+
135
+ # 2. Access key
136
+ if isinstance(current_data, dict) and key in current_data:
137
+ current_data = current_data[key]
138
+ else:
139
+ return None # Key missing
140
+
141
+ return current_data
142
+ except:
143
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
+ def _apply_projection(self, row, recipe):
146
+ """
147
+ Constructs a NEW row based on the target columns defined in recipe.
148
+ """
149
+ new_row = {}
150
+ for target in recipe['columns']:
151
+ # target = { "name": "new_col_name", "source": "old_col.nested.key" }
152
+ val = self._get_value_by_path(row, target['source'])
153
+ new_row[target['name']] = val
154
  return new_row
155
 
156
+ def _passes_filter(self, row, filter_str):
157
+ """
158
+ Filters are applied to the SOURCE row structure (before projection).
159
+ """
160
+ if not filter_str or not filter_str.strip():
161
+ return True
162
+ try:
163
+ # We must handle cases where 'row' has nested objects unparsed?
164
+ # For simplicity, we eval on the raw row dictionary.
165
+ # Users can use python: `json.loads(row['meta'])['url'] == ...`
166
+ # Or we can support the flattened context?
167
+ # Let's stick to raw row context for now.
168
+ context = row.copy()
169
+ return eval(filter_str, {}, context)
170
+ except:
171
+ return False # Fail safe
172
 
173
  def process_and_push(self, source_id, config, split, target_id, recipe, max_rows=None):
174
+ logger.info(f"Starting projection job: {source_id} -> {target_id}")
 
175
  conf = config if config != 'default' else None
176
 
177
  def gen():
 
181
  if max_rows and count >= int(max_rows):
182
  break
183
 
184
+ # 1. Filter (Source)
185
+ if self._passes_filter(row, recipe.get('filter_rule')):
186
+ # 2. Project (Build new row)
187
+ new_row = self._apply_projection(row, recipe)
188
+ yield new_row
 
 
 
 
189
  count += 1
190
 
 
 
 
191
  try:
192
+ # Create new dataset from generator (Auto-infers schema from first yielded dict)
193
  new_dataset = datasets.Dataset.from_generator(gen)
194
  new_dataset.push_to_hub(target_id, token=self.token)
195
  return {"status": "success", "rows_processed": len(new_dataset)}
196
  except Exception as e:
197
  logger.error(e)
198
  raise e
199
+
200
  def preview_transform(self, dataset_id, config, split, recipe):
201
  conf = config if config != 'default' else None
202
  ds_stream = load_dataset(dataset_id, name=conf, split=split, streaming=True, token=self.token)
203
  processed = []
204
+
205
+ for row in ds_stream:
206
  if len(processed) >= 5: break
207
+
208
+ if self._passes_filter(row, recipe.get('filter_rule')):
209
+ new_row = self._apply_projection(row, recipe)
210
+ processed.append(new_row)
211
+
212
  return processed