File size: 4,170 Bytes
2797314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d775b3
 
 
 
 
2797314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Local Schema Extraction (No LLM)
Fast, cheap extraction of dataset metadata without sending to LLM.
"""

import polars as pl
from pathlib import Path
from typing import Dict, Any, Optional


def extract_schema_local(file_path: str, sample_rows: int = 5) -> Dict[str, Any]:
    """
    Extract dataset schema and basic stats locally without sending to LLM.

    Args:
        file_path: Path to a ``.csv`` or ``.parquet`` file; any other
            extension falls back to pandas' CSV reader.
        sample_rows: Number of leading rows to include as a sample.

    Returns:
        Dict containing:
            - column names, dtypes, missing counts/percentages, unique counts
            - row/column counts and file size
            - a small row sample for LLM context
            - numeric / categorical / datetime column groupings
        On failure, a dict with an ``'error'`` message and the file path.
    """
    try:
        # Read with Polars (faster than pandas)
        if file_path.endswith('.csv'):
            df = pl.read_csv(file_path)
        elif file_path.endswith('.parquet'):
            df = pl.read_parquet(file_path)
        else:
            # Unknown extension: fall back to pandas, then convert.
            import pandas as pd
            pdf = pd.read_csv(file_path)
            df = pl.from_pandas(pdf)

        num_rows, num_columns = df.shape

        # Basic metadata
        schema_info: Dict[str, Any] = {
            'file_path': file_path,
            'file_size_mb': round(Path(file_path).stat().st_size / (1024 * 1024), 2),
            'num_rows': num_rows,
            'num_columns': num_columns,
            'columns': {}
        }

        # Per-column metadata
        for col in df.columns:
            col_series = df[col]
            dtype_str = str(col_series.dtype)
            null_count = col_series.null_count()  # compute once, used twice below

            col_info = {
                'dtype': dtype_str,
                'missing_count': null_count,
                # Guard: an empty frame would otherwise raise ZeroDivisionError.
                'missing_pct': round(null_count / num_rows * 100, 2) if num_rows else 0.0,
                # n_unique is O(n); skip it for huge datasets.
                'unique_count': col_series.n_unique() if num_rows < 100000 else None
            }

            # Type-specific stats (lightweight). min/max/mean return None on
            # all-null columns; float(None) raises TypeError, so we skip those
            # stats rather than fail the whole extraction.
            if dtype_str in ('Int64', 'Float64', 'Int32', 'Float32'):
                try:
                    col_info['min'] = float(col_series.min())
                    col_info['max'] = float(col_series.max())
                    col_info['mean'] = float(col_series.mean())
                except (TypeError, ValueError):
                    pass  # e.g. column is entirely null

            schema_info['columns'][col] = col_info

        # Small sample for LLM context (only first few rows)
        schema_info['sample_rows'] = df.head(sample_rows).to_dicts()

        # Categorize columns by dtype / cardinality.
        schema_info['numeric_columns'] = [
            col for col, info in schema_info['columns'].items()
            if 'Int' in info['dtype'] or 'Float' in info['dtype']
        ]
        schema_info['categorical_columns'] = [
            col for col, info in schema_info['columns'].items()
            if info['dtype'] in ['Utf8', 'String'] or (
                info.get('unique_count') is not None and
                info.get('unique_count') < 50 and
                col not in schema_info['numeric_columns']
            )
        ]
        schema_info['datetime_columns'] = [
            col for col, info in schema_info['columns'].items()
            if 'Date' in info['dtype'] or 'Time' in info['dtype']
        ]

        return schema_info

    except Exception as e:
        # Boundary handler: report failure as data instead of raising,
        # so callers always receive a dict.
        return {
            'error': f"Failed to extract schema: {str(e)}",
            'file_path': file_path
        }


def infer_task_type(target_column: str, schema_info: Dict[str, Any]) -> Optional[str]:
    """
    Infer ML task type from the target column's metadata without an LLM.

    Args:
        target_column: Name of the prediction target.
        schema_info: Schema dict as produced by ``extract_schema_local``
            (per-column ``dtype`` and optional ``unique_count``).

    Returns:
        ``'regression'``, ``'classification'``, or ``None`` when the
        column is missing or the type cannot be determined.
    """
    if not target_column or target_column not in schema_info.get('columns', {}):
        return None

    target_info = schema_info['columns'][target_column]
    dtype = target_info['dtype']
    # unique_count may be absent OR explicitly None (extraction skips it for
    # huge datasets); normalize both to None so comparisons below are safe.
    unique_count = target_info.get('unique_count')

    # Numeric with many unique values → regression; very few → classification.
    if dtype in ['Int64', 'Float64', 'Int32', 'Float32']:
        if unique_count is not None:
            if unique_count > 20:
                return 'regression'
            if unique_count <= 10:
                return 'classification'
        # 11–20 unique values falls through to the cardinality check below.

    # String dtype or low cardinality → classification.
    # Original compared a possibly-None unique_count with <=, which raised
    # TypeError; treat unknown cardinality as "not low".
    if dtype in ['Utf8', 'String'] or (unique_count is not None and unique_count <= 20):
        return 'classification'

    return None