File size: 9,307 Bytes
5c97387
 
97353a3
 
5c97387
 
 
 
 
 
 
 
 
 
97353a3
5c97387
97353a3
5c97387
 
97353a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c97387
 
 
 
97353a3
 
 
 
 
 
 
 
 
 
 
5c97387
 
 
97353a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c97387
 
 
 
97353a3
5c97387
 
 
 
 
 
97353a3
5c97387
 
 
 
 
 
 
97353a3
 
 
 
 
 
5c97387
97353a3
 
 
 
 
5c97387
97353a3
 
 
 
 
 
 
 
 
 
 
 
 
 
5c97387
 
97353a3
 
 
 
 
5c97387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97353a3
5c97387
 
 
97353a3
5c97387
 
97353a3
5c97387
 
 
97353a3
 
5c97387
97353a3
5c97387
 
97353a3
5c97387
 
 
 
 
97353a3
 
 
 
 
 
 
 
 
5c97387
 
97353a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import json
import logging
import datasets
from datasets import load_dataset, get_dataset_config_names, get_dataset_infos
from huggingface_hub import HfApi

# Configure logging at import time. NOTE(review): basicConfig is a no-op if
# the embedding application has already configured the root logger, so this
# is safe in library use; it only takes effect for standalone runs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DatasetCommandCenter:
    """Streams, transforms, and republishes Hugging Face Hub datasets.

    Workflow: discover configs/splits without downloading data, inspect a
    handful of streamed sample rows, apply a transformation "recipe"
    (JSON-column expansion, renames, drops, row filters), then push the
    transformed stream to a target Hub repository.
    """

    def __init__(self, token=None):
        # Hub access token; None means anonymous / public access only.
        self.token = token

    def get_dataset_metadata(self, dataset_id):
        """Step 1: list available configs (subsets) and splits without downloading rows.

        Returns ``{"status": "success", "configs": [...], "splits": [...]}``
        on success, or ``{"status": "error", "message": ...}`` on an
        unexpected failure. Each lookup has its own fallback so flaky Hub
        metadata for a single dataset does not abort the whole step.
        """
        try:
            # 1. Configs (e.g. 'en', 'fr', or 'default').
            try:
                configs = get_dataset_config_names(dataset_id, token=self.token)
            except Exception:
                # Some datasets have no configs or raise here; assume a
                # single unnamed config.
                configs = ['default']

            # 2. Splits for the first config (to pre-populate the UI).
            # Splits for other configs are fetched on demand via
            # get_splits_for_config().
            selected_config = configs[0]

            try:
                # Fetches metadata (splits, columns) without downloading rows.
                infos = get_dataset_infos(dataset_id, token=self.token)
                # infos is normally a dict keyed by config name.
                if selected_config in infos:
                    splits = list(infos[selected_config].splits.keys())
                else:
                    # Fallback when infos is not keyed by our config name:
                    # take the first (often only) entry.
                    splits = list(next(iter(infos.values())).splits.keys())
            except Exception:
                # Last resort: the conventional split names.
                splits = ['train', 'test', 'validation']

            return {
                "status": "success",
                "configs": configs,
                "splits": list(splits),
            }
        except Exception as e:
            return {"status": "error", "message": str(e)}

    def get_splits_for_config(self, dataset_id, config_name):
        """Return splits for *config_name* (used when the Config dropdown changes).

        Always reports ``status: success``; on lookup failure it logs the
        cause and falls back to the conventional split names so the UI stays
        usable.
        """
        try:
            infos = get_dataset_infos(dataset_id, config_name=config_name, token=self.token)
            splits = list(infos[config_name].splits.keys())
            return {"status": "success", "splits": splits}
        except Exception as e:
            # Deliberate best-effort: surface defaults rather than an error.
            logger.warning("Could not fetch splits for %s/%s: %s",
                           dataset_id, config_name, e)
            return {"status": "success", "splits": ['train', 'test', 'validation']}

    def inspect_dataset(self, dataset_id, config, split):
        """Step 2: stream a few actual rows and detect JSON-encoded string columns.

        Returns ``{"status": "success", "samples": [...], "analysis": {...}}``
        where ``analysis`` maps column name -> ``{"type", "is_json_string"}``,
        or an error dict on failure / empty dataset.
        """
        try:
            # 'default' is our placeholder for "no named config".
            conf = config if config != 'default' else None

            ds_stream = load_dataset(dataset_id, name=conf, split=split,
                                     streaming=True, token=self.token)

            sample_rows = []
            for i, row in enumerate(ds_stream):
                if i >= 5:
                    break
                # Convert non-serializable objects (like PIL Images) to
                # strings so the preview can be JSON-encoded.
                clean_row = {}
                for k, v in row.items():
                    if not isinstance(v, (str, int, float, bool, list, dict, type(None))):
                        clean_row[k] = str(v)
                    else:
                        clean_row[k] = v
                sample_rows.append(clean_row)

            if not sample_rows:
                return {"status": "error", "message": "Dataset is empty."}

            # Analyze columns based on the first sample row.
            analysis = {}
            for k in sample_rows[0].keys():
                sample_val = sample_rows[0][k]
                col_type = type(sample_val).__name__
                is_json_str = False

                # Heuristic: a string that brackets like JSON and parses
                # cleanly is treated as a JSON-encoded column.
                if isinstance(sample_val, str):
                    s = sample_val.strip()
                    if (s.startswith('{') and s.endswith('}')) or \
                       (s.startswith('[') and s.endswith(']')):
                        try:
                            json.loads(s)
                            is_json_str = True
                        except ValueError:
                            pass

                analysis[k] = {
                    "type": col_type,
                    "is_json_string": is_json_str,
                }

            return {
                "status": "success",
                "samples": sample_rows,
                "analysis": analysis,
            }
        except Exception as e:
            return {"status": "error", "message": str(e)}

    def _apply_transformations(self, row, recipe):
        """Apply JSON expansion, renames, and drops from *recipe* to one row.

        Returns a new (shallow-copied) dict; *row* itself is not mutated.
        Filtering is handled separately by _passes_filter().
        """
        new_row = row.copy()

        # 1. JSON expansions: pull keys out of a dict / JSON-string column
        # into new top-level columns named "<col>_<key path>".
        if "json_expansions" in recipe:
            for item in recipe["json_expansions"]:
                col_name = item["col"]
                target_keys = item["keys"]

                source_data = new_row.get(col_name)
                parsed_obj = None

                # Case A: already a dict (Arrow struct).
                if isinstance(source_data, dict):
                    parsed_obj = source_data
                # Case B: a JSON-encoded string.
                elif isinstance(source_data, str):
                    try:
                        parsed_obj = json.loads(source_data)
                    except ValueError:
                        pass

                # "is not None" (not truthiness) so an empty dict still
                # yields the None-filled columns below, keeping the output
                # schema stable across rows.
                if parsed_obj is not None:
                    for key in target_keys:
                        # Dot notation walks nested objects (e.g. "meta.url").
                        val = parsed_obj
                        try:
                            for part in key.split('.'):
                                val = val[part]
                            # New column name: dots become underscores.
                            new_row[f"{col_name}_{key.replace('.', '_')}"] = val
                        except (KeyError, TypeError, IndexError):
                            # Key path not found: emit None to keep schema stable.
                            new_row[f"{col_name}_{key.replace('.', '_')}"] = None

        # 2. Renames (old name -> new name).
        if "renames" in recipe:
            for old, new in recipe["renames"].items():
                if old in new_row:
                    new_row[new] = new_row.pop(old)

        # 3. Drops.
        if "drops" in recipe:
            for drop_col in recipe["drops"]:
                if drop_col in new_row:
                    del new_row[drop_col]

        return new_row

    def _passes_filter(self, row, filters):
        """Return True iff *row* satisfies every expression in *filters*.

        Each filter is a Python expression evaluated with the row's columns
        as local variables. A filter that raises (e.g. missing column)
        rejects the row rather than crashing the job.

        SECURITY NOTE(review): eval() on user-supplied filter strings is
        arbitrary code execution — acceptable only because filters come
        from the operator driving this tool, never from untrusted input.
        """
        if not filters:
            return True
        context = row.copy()
        for f_str in filters:
            try:
                if not eval(f_str, {}, context):
                    return False
            except Exception:
                return False
        return True

    def process_and_push(self, source_id, config, split, target_id, recipe, max_rows=None):
        """Stream the source split, transform + filter each row, push to *target_id*.

        Returns ``{"status": "success", "rows_processed": N}`` or re-raises
        the underlying error after logging it.
        """
        logger.info(f"Starting job: {source_id} ({config}/{split}) -> {target_id}")

        conf = config if config != 'default' else None
        # Hoisted out of the per-row loop; falsy max_rows (None/0) = no limit,
        # matching the original behavior.
        limit = int(max_rows) if max_rows else None

        def gen():
            ds_stream = load_dataset(source_id, name=conf, split=split,
                                     streaming=True, token=self.token)
            count = 0
            for row in ds_stream:
                if limit is not None and count >= limit:
                    break
                # Transform BEFORE filtering so filters can reference
                # columns extracted from JSON expansions.
                trans_row = self._apply_transformations(row, recipe)
                if self._passes_filter(trans_row, recipe.get("filters", [])):
                    yield trans_row
                    count += 1

        # Push to Hub. Using a generator lets HF auto-detect the new schema.
        try:
            new_dataset = datasets.Dataset.from_generator(gen)
            new_dataset.push_to_hub(target_id, token=self.token)
            return {"status": "success", "rows_processed": len(new_dataset)}
        except Exception:
            logger.exception("Job failed: %s -> %s", source_id, target_id)
            # Bare raise preserves the original traceback for the caller.
            raise

    def preview_transform(self, dataset_id, config, split, recipe):
        """Return up to 5 transformed + filtered rows for a live preview."""
        conf = config if config != 'default' else None
        ds_stream = load_dataset(dataset_id, name=conf, split=split,
                                 streaming=True, token=self.token)
        processed = []
        for row in ds_stream:
            if len(processed) >= 5:
                break
            trans_row = self._apply_transformations(row, recipe)
            if self._passes_filter(trans_row, recipe.get("filters", [])):
                processed.append(trans_row)
        return processed