broadfield-dev committed on
Commit
a76c50f
·
verified ·
1 Parent(s): f37234a

Update processor.py

Browse files
Files changed (1) hide show
  1. processor.py +145 -106
processor.py CHANGED
@@ -3,7 +3,6 @@ import logging
3
  import datasets
4
  from datasets import load_dataset, get_dataset_config_names, get_dataset_infos
5
 
6
- # Configure logging
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
9
 
@@ -17,8 +16,6 @@ class DatasetCommandCenter:
17
  configs = get_dataset_config_names(dataset_id, token=self.token)
18
  except:
19
  configs = ['default']
20
-
21
- # Try to fetch splits for the first config
22
  try:
23
  infos = get_dataset_infos(dataset_id, token=self.token)
24
  first_conf = configs[0]
@@ -28,7 +25,6 @@ class DatasetCommandCenter:
28
  splits = list(infos.values())[0].splits.keys()
29
  except:
30
  splits = ['train', 'test', 'validation']
31
-
32
  return {"status": "success", "configs": configs, "splits": list(splits)}
33
  except Exception as e:
34
  return {"status": "error", "message": str(e)}
@@ -41,172 +37,215 @@ class DatasetCommandCenter:
41
  except:
42
  return {"status": "success", "splits": ['train', 'test', 'validation']}
43
 
44
- def _flatten_object(self, obj, parent_key='', sep='.'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  """
46
- Recursively finds all keys in a nested dictionary (or JSON string).
47
- Returns a dict of { 'path': sample_value }.
 
48
  """
49
- items = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- # If it's a string, try to parse it as JSON first
 
 
 
 
 
 
 
 
 
 
52
  if isinstance(obj, str):
53
- obj = obj.strip()
54
- if (obj.startswith('{') and obj.endswith('}')):
55
  try:
56
- obj = json.loads(obj)
57
  except:
58
- pass # It's just a string
59
 
60
  if isinstance(obj, dict):
61
  for k, v in obj.items():
62
- new_key = f"{parent_key}{sep}{k}" if parent_key else k
63
- items.update(self._flatten_object(v, new_key, sep=sep))
64
- else:
65
- # It's a leaf node (int, str, list, etc.)
66
- items[parent_key] = obj
 
 
 
67
 
68
  return items
69
 
70
  def inspect_dataset(self, dataset_id, config, split):
71
- """
72
- Scans first N rows to build a map of ALL available fields (including nested JSON).
73
- """
74
  try:
75
  conf = config if config != 'default' else None
76
  ds_stream = load_dataset(dataset_id, name=conf, split=split, streaming=True, token=self.token)
77
 
78
  sample_rows = []
79
- available_paths = set()
80
-
81
- # We scan 20 rows to find schema variations (some JSON keys might be optional)
82
  for i, row in enumerate(ds_stream):
83
- if i >= 20: break
84
 
85
- # Store a clean version for UI preview
86
  clean_row = {}
87
  for k, v in row.items():
88
- # Handle bytes/images
89
  if not isinstance(v, (str, int, float, bool, list, dict, type(None))):
90
- clean_row[k] = f"<{type(v).__name__}>"
91
  else:
92
  clean_row[k] = v
93
  sample_rows.append(clean_row)
94
 
95
- # Schema Inference: Flatten this row to find all possible dot-notation paths
96
- flattened = self._flatten_object(row)
97
- available_paths.update(flattened.keys())
98
-
99
- # Sort paths naturally
100
- sorted_paths = sorted(list(available_paths))
101
-
102
- # Group paths by top-level column for the UI
103
- schema_tree = {}
104
- for path in sorted_paths:
105
- root = path.split('.')[0]
106
- if root not in schema_tree:
107
- schema_tree[root] = []
108
- schema_tree[root].append(path)
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  return {
111
- "status": "success",
112
- "samples": sample_rows[:5], # Send 5 to frontend
113
- "schema_tree": schema_tree, # { 'meta': ['meta', 'meta.url', 'meta.id'] }
114
  "dataset_id": dataset_id
115
  }
116
  except Exception as e:
117
  return {"status": "error", "message": str(e)}
118
 
119
- def _get_value_by_path(self, row, path):
120
- """
121
- Extracts value using dot notation, parsing JSON strings on the fly if needed.
122
- """
123
- keys = path.split('.')
124
- current_data = row
125
-
126
- try:
127
- for i, key in enumerate(keys):
128
- # 1. If current_data is a JSON string, parse it
129
- if isinstance(current_data, str):
130
- try:
131
- current_data = json.loads(current_data)
132
- except:
133
- return None # Parsing failed
134
-
135
- # 2. Access key
136
- if isinstance(current_data, dict) and key in current_data:
137
- current_data = current_data[key]
138
- else:
139
- return None # Key missing
140
-
141
- return current_data
142
- except:
143
- return None
144
-
145
  def _apply_projection(self, row, recipe):
146
- """
147
- Constructs a NEW row based on the target columns defined in recipe.
148
- """
149
  new_row = {}
150
- for target in recipe['columns']:
151
- # target = { "name": "new_col_name", "source": "old_col.nested.key" }
152
- val = self._get_value_by_path(row, target['source'])
153
- new_row[target['name']] = val
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  return new_row
155
 
156
  def _passes_filter(self, row, filter_str):
157
- """
158
- Filters are applied to the SOURCE row structure (before projection).
159
- """
160
- if not filter_str or not filter_str.strip():
161
- return True
162
  try:
163
- # We must handle cases where 'row' has nested objects unparsed?
164
- # For simplicity, we eval on the raw row dictionary.
165
- # Users can use python: `json.loads(row['meta'])['url'] == ...`
166
- # Or we can support the flattened context?
167
- # Let's stick to raw row context for now.
168
  context = row.copy()
169
  return eval(filter_str, {}, context)
170
  except:
171
- return False # Fail safe
172
 
173
  def process_and_push(self, source_id, config, split, target_id, recipe, max_rows=None):
174
- logger.info(f"Starting projection job: {source_id} -> {target_id}")
175
  conf = config if config != 'default' else None
176
 
177
  def gen():
178
  ds_stream = load_dataset(source_id, name=conf, split=split, streaming=True, token=self.token)
179
  count = 0
180
  for row in ds_stream:
181
- if max_rows and count >= int(max_rows):
182
- break
183
 
184
- # 1. Filter (Source)
185
  if self._passes_filter(row, recipe.get('filter_rule')):
186
- # 2. Project (Build new row)
187
- new_row = self._apply_projection(row, recipe)
188
- yield new_row
189
  count += 1
190
-
191
  try:
192
- # Create new dataset from generator (Auto-infers schema from first yielded dict)
193
  new_dataset = datasets.Dataset.from_generator(gen)
194
  new_dataset.push_to_hub(target_id, token=self.token)
195
  return {"status": "success", "rows_processed": len(new_dataset)}
196
  except Exception as e:
197
- logger.error(e)
198
- raise e
199
-
200
  def preview_transform(self, dataset_id, config, split, recipe):
201
  conf = config if config != 'default' else None
202
  ds_stream = load_dataset(dataset_id, name=conf, split=split, streaming=True, token=self.token)
203
  processed = []
204
-
205
  for row in ds_stream:
206
  if len(processed) >= 5: break
207
-
208
  if self._passes_filter(row, recipe.get('filter_rule')):
209
- new_row = self._apply_projection(row, recipe)
210
- processed.append(new_row)
211
-
212
  return processed
 
3
  import datasets
4
  from datasets import load_dataset, get_dataset_config_names, get_dataset_infos
5
 
 
6
  logging.basicConfig(level=logging.INFO)
7
  logger = logging.getLogger(__name__)
8
 
 
16
  configs = get_dataset_config_names(dataset_id, token=self.token)
17
  except:
18
  configs = ['default']
 
 
19
  try:
20
  infos = get_dataset_infos(dataset_id, token=self.token)
21
  first_conf = configs[0]
 
25
  splits = list(infos.values())[0].splits.keys()
26
  except:
27
  splits = ['train', 'test', 'validation']
 
28
  return {"status": "success", "configs": configs, "splits": list(splits)}
29
  except Exception as e:
30
  return {"status": "error", "message": str(e)}
 
37
  except:
38
  return {"status": "success", "splits": ['train', 'test', 'validation']}
39
 
40
+ # --- HELPER: Recursive JSON/Dot Notation Getter ---
41
+ def _get_value_by_path(self, obj, path):
42
+ if not path: return obj
43
+ keys = path.split('.')
44
+ current = obj
45
+
46
+ for key in keys:
47
+ # Auto-parse JSON string if encountered
48
+ if isinstance(current, str):
49
+ s = current.strip()
50
+ if (s.startswith('{') and s.endswith('}')) or (s.startswith('[') and s.endswith(']')):
51
+ try:
52
+ current = json.loads(s)
53
+ except:
54
+ pass
55
+
56
+ if isinstance(current, dict) and key in current:
57
+ current = current[key]
58
+ else:
59
+ return None
60
+ return current
61
+
62
+ # --- HELPER: List Search Logic ---
63
+ def _extract_from_list_logic(self, row, source_col, filter_key, filter_val, target_path):
64
  """
65
+ Logic: Look inside row[source_col] (which is a list).
66
+ Find first item where item[filter_key] == filter_val.
67
+ Then extract item[target_path].
68
  """
69
+ # 1. Get the list (handling JSON string if needed)
70
+ data = row.get(source_col)
71
+ if isinstance(data, str):
72
+ try:
73
+ data = json.loads(data)
74
+ except:
75
+ return None
76
+
77
+ if not isinstance(data, list):
78
+ return None
79
+
80
+ # 2. Search the list
81
+ matched_item = None
82
+ for item in data:
83
+ # We treat values as strings for comparison to be safe
84
+ if str(item.get(filter_key, '')) == str(filter_val):
85
+ matched_item = item
86
+ break
87
+
88
+ if matched_item:
89
+ # 3. Extract the target (supporting nested json parsing via dot notation)
90
+ # e.g. target_path = "content.analysis"
91
+ return self._get_value_by_path(matched_item, target_path)
92
 
93
+ return None
94
+
95
+ def _flatten_schema(self, obj, parent='', visited=None):
96
+ if visited is None: visited = set()
97
+ items = []
98
+
99
+ # Avoid infinite recursion
100
+ if id(obj) in visited: return []
101
+ visited.add(id(obj))
102
+
103
+ # Handle JSON strings
104
  if isinstance(obj, str):
105
+ s = obj.strip()
106
+ if (s.startswith('{') and s.endswith('}')) or (s.startswith('[') and s.endswith(']')):
107
  try:
108
+ obj = json.loads(s)
109
  except:
110
+ pass
111
 
112
  if isinstance(obj, dict):
113
  for k, v in obj.items():
114
+ full_key = f"{parent}.{k}" if parent else k
115
+ items.append((full_key, type(v).__name__))
116
+ items.extend(self._flatten_schema(v, full_key, visited))
117
+ elif isinstance(obj, list) and len(obj) > 0:
118
+ # For lists, we just peek at the first item to guess schema
119
+ full_key = f"{parent}[]" if parent else "[]"
120
+ items.append((parent, "List")) # Mark the parent as a List
121
+ items.extend(self._flatten_schema(obj[0], full_key, visited))
122
 
123
  return items
124
 
125
def inspect_dataset(self, dataset_id, config, split):
    """Stream the first 10 rows of a split and infer a per-column schema.

    Returns on success:
        {"status": "success", "samples": [...], "schema": {...}, "dataset_id": ...}
    where ``schema`` maps column -> {"type": "List"|"Object", "keys": [...]}
    (keys sampled from dict values / first dict element of list values,
    including JSON-string columns). On any failure:
        {"status": "error", "message": ...}
    """
    try:
        conf = config if config != 'default' else None
        ds_stream = load_dataset(dataset_id, name=conf, split=split, streaming=True, token=self.token)

        sample_rows = []
        schema_map = {}  # { "col_name": {"is_list": bool, "keys": set()} }

        for i, row in enumerate(ds_stream):
            if i >= 10:
                break

            # Create a JSON-serializable sample for the UI preview.
            clean_row = {}
            for k, v in row.items():
                if not isinstance(v, (str, int, float, bool, list, dict, type(None))):
                    # Non-serializable values (e.g. image objects) -> str()
                    clean_row[k] = str(v)
                else:
                    clean_row[k] = v
            sample_rows.append(clean_row)

            # Analyze schema: detect list/object columns, incl. JSON strings.
            for k, v in row.items():
                if k not in schema_map:
                    schema_map[k] = {"is_list": False, "keys": set()}

                val = v
                if isinstance(val, str):
                    try:
                        val = json.loads(val)
                    except ValueError:
                        pass  # plain string column; narrowed from bare except

                if isinstance(val, list):
                    schema_map[k]["is_list"] = True
                    if len(val) > 0 and isinstance(val[0], dict):
                        schema_map[k]["keys"].update(val[0].keys())
                elif isinstance(val, dict):
                    schema_map[k]["keys"].update(val.keys())

        # Format schema for the UI (sets -> lists for JSON transport).
        formatted_schema = {}
        for k, info in schema_map.items():
            formatted_schema[k] = {
                "type": "List" if info["is_list"] else "Object",
                "keys": list(info["keys"])
            }

        return {
            "status": "success",
            "samples": sample_rows,
            "schema": formatted_schema,
            "dataset_id": dataset_id
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  def _apply_projection(self, row, recipe):
 
 
 
182
  new_row = {}
183
+ for col_def in recipe['columns']:
184
+ t_type = col_def.get('type', 'simple')
185
+
186
+ if t_type == 'simple':
187
+ # Standard Dot Notation
188
+ new_row[col_def['name']] = self._get_value_by_path(row, col_def['source'])
189
+
190
+ elif t_type == 'list_search':
191
+ # GET x WHERE y=z
192
+ val = self._extract_from_list_logic(
193
+ row,
194
+ col_def['source'],
195
+ col_def['filter_key'],
196
+ col_def['filter_val'],
197
+ col_def['target_key']
198
+ )
199
+ new_row[col_def['name']] = val
200
+
201
+ elif t_type == 'python':
202
+ # Advanced Python Eval
203
+ try:
204
+ context = row.copy()
205
+ # We inject 'json' module into context for user scripts
206
+ context['json'] = json
207
+ val = eval(col_def['expression'], {}, context)
208
+ new_row[col_def['name']] = val
209
+ except:
210
+ new_row[col_def['name']] = None
211
+
212
  return new_row
213
 
214
  def _passes_filter(self, row, filter_str):
215
+ if not filter_str: return True
 
 
 
 
216
  try:
 
 
 
 
 
217
  context = row.copy()
218
  return eval(filter_str, {}, context)
219
  except:
220
+ return False
221
 
222
def process_and_push(self, source_id, config, split, target_id, recipe, max_rows=None):
    """Stream source rows, filter + project them per ``recipe``, and push
    the result to the Hub as ``target_id``.

    ``max_rows`` (optional) caps the number of OUTPUT rows yielded.

    Returns {"status": "success", "rows_processed": N} or
    {"status": "error", "message": ...}.
    """
    logger.info(f"Job started: {source_id}")
    conf = config if config != 'default' else None

    def gen():
        ds_stream = load_dataset(source_id, name=conf, split=split, streaming=True, token=self.token)
        count = 0
        for row in ds_stream:
            if max_rows and count >= int(max_rows):
                break
            if self._passes_filter(row, recipe.get('filter_rule')):
                yield self._apply_projection(row, recipe)
                count += 1

    try:
        # Dataset.from_generator auto-infers the schema from the first
        # yielded dict.
        new_dataset = datasets.Dataset.from_generator(gen)
        new_dataset.push_to_hub(target_id, token=self.token)
        return {"status": "success", "rows_processed": len(new_dataset)}
    except Exception as e:
        # Log the traceback before reporting — previously the error was
        # returned silently, leaving no server-side trace of the failure.
        logger.exception("Projection job failed: %s -> %s", source_id, target_id)
        return {"status": "error", "message": str(e)}
 
243
def preview_transform(self, dataset_id, config, split, recipe):
    """Apply the recipe's filter + projection to a streamed split and
    return up to 5 transformed rows for UI preview."""
    name = None if config == 'default' else config
    stream = load_dataset(dataset_id, name=name, split=split, streaming=True, token=self.token)
    results = []
    for record in stream:
        if len(results) >= 5:
            break
        if self._passes_filter(record, recipe.get('filter_rule')):
            results.append(self._apply_projection(record, recipe))
    return results