Spaces:
Running
Running
| """Column classifier for determining data sourcing strategy.""" | |
| import json | |
| import re | |
| from datetime import datetime, timedelta | |
| from typing import Optional | |
| from ..ddl.models import Schema, Table, Column, ColumnClassification | |
| from ..config import COLUMN_HINTS | |
| from .context_builder import CompanyContext | |
| from ..domain import infer_semantic_type, is_business_categorical | |
| def _years_ago(dt: datetime, years: int) -> datetime: | |
| """Return datetime shifted back by full years (leap-safe).""" | |
| try: | |
| return dt.replace(year=dt.year - years) | |
| except ValueError: | |
| return dt.replace(month=2, day=28, year=dt.year - years) | |
class ColumnClassifier:
    """Classifies columns to determine how their data should be sourced.

    Each column is mapped to a ColumnClassification (SEARCH_REAL, AI_GEN or
    GENERIC) plus a concrete "strategy" string that tells the data generator
    how to produce values (a web-search query, an AI prompt, or a faker spec
    such as "date_between:2020-01-01,2024-12-31").
    """

    def __init__(self, anthropic_client=None):
        """
        Initialize classifier.

        Args:
            anthropic_client: Optional Anthropic client for AI classification.
                When provided, AI classification is attempted first and
                heuristics are used as the fallback.
        """
        self.anthropic_client = anthropic_client

    def classify_schema(
        self,
        schema: Schema,
        context: CompanyContext,
        use_case: str
    ) -> dict[str, dict[str, dict]]:
        """
        Classify all columns in a schema.

        Args:
            schema: Parsed database schema
            context: Company context
            use_case: Analytics use case

        Returns:
            Dict mapping table_name -> column_name -> {classification, strategy}
        """
        if self.anthropic_client:
            return self._classify_with_ai(schema, context, use_case)
        return self._classify_with_heuristics(schema, context, use_case)

    def _classify_with_ai(
        self,
        schema: Schema,
        context: CompanyContext,
        use_case: str
    ) -> dict[str, dict[str, dict]]:
        """Use AI to classify columns; fall back to heuristics on any failure."""
        # Build schema description
        schema_desc = self._schema_to_description(schema)
        prompt = f"""Analyze this database schema and classify each column for realistic data generation.
Company Context:
{context.to_prompt()}
Use Case: {use_case}
Schema:
{schema_desc}
For each column, determine:
1. classification: How to source the data
- SEARCH_REAL: Use web search for real-world data (products, brands, companies, etc.)
- AI_GEN: Generate contextually appropriate data (regions, categories, descriptions)
- GENERIC: Use calculated/random data (IDs, dates, amounts, booleans)
2. strategy: Specific instructions
- For SEARCH_REAL: The search query to use (e.g., "consumer electronics brands")
- For AI_GEN: Description of what to generate (e.g., "US geographic regions for distribution")
- For GENERIC: The faker type or calculation (e.g., "date_between:2020-01-01,2024-12-31" or "random_int:1,1000")
Return as JSON:
{{
"TABLE_NAME": {{
"column_name": {{
"classification": "SEARCH_REAL|AI_GEN|GENERIC",
"strategy": "specific instructions"
}}
}}
}}
Important rules:
- Primary key and ID columns are always GENERIC
- Foreign key columns are always GENERIC (they reference other tables)
- Columns ending in _key, _id, _date, _at are typically GENERIC
- Product names, brand names, company names should be SEARCH_REAL
- Region, category, segment, tier, type should be AI_GEN
- Amounts, quantities, prices, ratings should be GENERIC
- Boolean flags (is_*, has_*) should be GENERIC
Return ONLY valid JSON, no other text."""
        try:
            response = self.anthropic_client.messages.create(
                model="claude-sonnet-4-20250514",
                max_tokens=4000,
                messages=[{"role": "user", "content": prompt}]
            )
            json_str = response.content[0].text.strip()
            # Handle markdown code blocks the model may wrap the JSON in.
            if json_str.startswith('```'):
                json_str = re.sub(r'^```\w*\n?', '', json_str)
                json_str = re.sub(r'\n?```$', '', json_str)
            result = json.loads(json_str)
            # Normalize and validate
            return self._normalize_classifications(result, schema)
        except Exception as e:
            # Deliberate best-effort: any API/parse failure degrades to heuristics.
            print(f"Warning: AI classification failed: {e}")
            return self._classify_with_heuristics(schema, context, use_case)

    def _classify_with_heuristics(
        self,
        schema: Schema,
        context: CompanyContext,
        use_case: str
    ) -> dict[str, dict[str, dict]]:
        """Classify columns using pattern-based heuristics."""
        result = {}
        for table in schema.tables:
            result[table.name] = {}
            # Upper-cased FK column names for case-insensitive membership tests.
            fk_columns = {fk.column_name.upper() for fk in table.foreign_keys}
            for column in table.columns:
                classification, strategy = self._classify_column_heuristic(
                    column, table, fk_columns, context, use_case
                )
                result[table.name][column.name] = {
                    "classification": classification.value,
                    "strategy": strategy
                }
        return result

    def _classify_column_heuristic(
        self,
        column: Column,
        table: Table,
        fk_columns: set[str],
        context: CompanyContext,
        use_case: str
    ) -> tuple[ColumnClassification, str]:
        """Classify a single column using heuristics.

        Precedence: PK/identity -> FK -> business-categorical -> SEARCH_REAL
        patterns -> AI_GEN patterns -> GENERIC patterns -> data-type default.
        """
        col_name = column.name.lower()
        col_upper = column.name.upper()
        semantic = infer_semantic_type(column.name, column.data_type, table.name, use_case)
        # Primary keys and identity columns
        if column.is_primary_key or column.is_identity:
            return ColumnClassification.GENERIC, "identity"
        # Foreign keys
        if col_upper in fk_columns:
            return ColumnClassification.GENERIC, "foreign_key"
        # Quality-first: keep business categorical columns contextual.
        if is_business_categorical(semantic):
            prompt = self._generate_ai_prompt(col_name, table.name, context, use_case)
            return ColumnClassification.AI_GEN, prompt
        # Check patterns for SEARCH_REAL
        for pattern in COLUMN_HINTS["search_real_patterns"]:
            if re.search(pattern, col_name, re.IGNORECASE):
                query = self._generate_search_query(col_name, context, use_case)
                return ColumnClassification.SEARCH_REAL, query
        # Check patterns for AI_GEN
        for pattern in COLUMN_HINTS["ai_gen_patterns"]:
            if re.search(pattern, col_name, re.IGNORECASE):
                prompt = self._generate_ai_prompt(col_name, table.name, context, use_case)
                return ColumnClassification.AI_GEN, prompt
        # Check patterns for GENERIC
        for pattern in COLUMN_HINTS["generic_patterns"]:
            if re.search(pattern, col_name, re.IGNORECASE):
                faker_type = self._infer_faker_type(column)
                return ColumnClassification.GENERIC, faker_type
        # Default based on data type
        return self._classify_by_type(column)

    def _generate_search_query(
        self,
        col_name: str,
        context: CompanyContext,
        use_case: str
    ) -> str:
        """Generate a web-search query string for a SEARCH_REAL column."""
        industry = context.industry.lower()
        if 'product' in col_name.lower():
            return f"{industry} products list"
        elif 'brand' in col_name.lower():
            return f"{industry} brands companies"
        elif 'seller' in col_name.lower() or 'vendor' in col_name.lower():
            return f"{industry} vendors suppliers companies"
        elif 'company' in col_name.lower() or 'customer' in col_name.lower():
            return f"{industry} companies businesses"
        else:
            return f"{industry} {col_name.replace('_', ' ')}"

    def _generate_ai_prompt(
        self,
        col_name: str,
        table_name: str,
        context: CompanyContext,
        use_case: str
    ) -> str:
        """Generate a generation prompt for an AI_GEN column."""
        col_clean = col_name.replace('_', ' ').title()
        return f"Generate realistic {col_clean} values for {context.company_name} ({context.industry}). Use case: {use_case}."

    def _infer_faker_type(self, column: Column) -> str:
        """Infer the faker strategy spec from column name and data type.

        Check order matters: booleans before phone/contact heuristics,
        birth dates before generic dates, semantic numerics before
        generic ID/type defaults.
        """
        col_name = column.name.lower()
        data_type = column.data_type.upper()
        # BOOLEAN COLUMNS - Check FIRST before phone/contact heuristics
        # so names like MobileTransaction BOOLEAN don't get treated as
        # phone numbers just because they contain "mobile".
        if 'BOOLEAN' in data_type or data_type in ('BOOL', 'BIT'):
            return "boolean:0.8"
        # BOOLEAN NAME PATTERNS - still keep these ahead of contact fields.
        if col_name.startswith('is_') or col_name.startswith('has_') or col_name.endswith('_flag'):
            return "boolean:0.5"
        # Email columns
        if 'email' in col_name:
            return "email"
        # Phone columns (but NOT is_mobile - that's boolean, checked above)
        if 'phone' in col_name or 'mobile' in col_name or 'tel' in col_name:
            return "phone"
        # Identifier-like short codes should not be treated as monetary values.
        if (
            'tax_id' in col_name
            or 'taxid' in col_name
            or 'last4' in col_name
            or col_name.endswith('_id_last4')
        ):
            if any(t in data_type for t in ('CHAR', 'TEXT', 'STRING', 'VARCHAR')):
                return "random_string:4"
            return "random_int:1000,9999"
        # Name columns
        if col_name in ('first_name', 'firstname', 'fname', 'given_name'):
            return "first_name"
        if col_name in ('last_name', 'lastname', 'lname', 'surname'):
            return "last_name"
        if col_name in ('full_name', 'fullname', 'name', 'user_name'):
            return "name"
        if col_name in ('customer_name', 'account_name', 'client_name', 'company_name', 'organization_name'):
            return "company"
        # Address columns
        if 'address' in col_name and 'email' not in col_name:
            return "address"
        if col_name in ('city', 'town'):
            return "city"
        if col_name in ('state', 'province'):
            return "state"
        if col_name in ('country', 'nation'):
            return "country"
        if 'zip' in col_name or 'postal' in col_name:
            return "zipcode"
        # Birth date columns - must be evaluated before generic date checks
        # to avoid impossible ages from generic recent-date generation.
        today_dt = datetime.now()
        today = today_dt.strftime('%Y-%m-%d')
        is_birth_col = (
            'birthdate' in col_name or
            'birth_date' in col_name or
            'date_of_birth' in col_name or
            col_name == 'dob' or
            col_name.endswith('_dob')
        )
        if is_birth_col:
            # Adults only: ages 18-95 as of today.
            normalized_today = today_dt.replace(hour=0, minute=0, second=0, microsecond=0)
            oldest = _years_ago(normalized_today, 95).strftime('%Y-%m-%d')
            youngest = _years_ago(normalized_today, 18).strftime('%Y-%m-%d')
            return f"date_between:{oldest},{youngest}"
        # Date columns - use current date for end range.
        # Guard against financial names like *_AMOUNT_TO_DATE on numeric fields.
        financial_to_date_name = (
            ('to_date' in col_name or col_name.endswith('todate'))
            and any(
                token in col_name
                for token in (
                    'amount',
                    'balance',
                    'cost',
                    'price',
                    'revenue',
                    'sales',
                    'fee',
                    'tax',
                    'total',
                    'recovery',
                    'payment',
                )
            )
        )
        is_name_date = ('date' in col_name) and not financial_to_date_name
        if is_name_date or data_type in ('DATE', 'TIMESTAMP', 'DATETIME'):
            if 'created' in col_name or 'registration' in col_name:
                return f"date_between:2020-01-01,{today}"
            elif 'launch' in col_name or 'open' in col_name:
                return f"date_between:2015-01-01,{today}"
            else:
                return f"date_between:2024-01-01,{today}"
        # NOTE: a second BOOLEAN / is_*,has_* check used to live here; it was
        # unreachable (both conditions are fully handled at the top) and has
        # been removed.
        # Numeric columns
        if 'quantity' in col_name:
            return "random_int:1,100"
        elif 'discount' in col_name:
            return "random_decimal:0.00,200.00"
        elif 'shipping' in col_name:
            return "random_decimal:2.00,150.00"
        elif 'tax' in col_name:
            return "random_decimal:1.00,250.00"
        elif 'revenue' in col_name:
            return "random_decimal:10.00,10000.00"
        elif 'amount' in col_name:
            return "random_decimal:5.00,5000.00"
        elif 'cost' in col_name or 'fee' in col_name or 'price' in col_name:
            return "random_decimal:1.00,500.00"
        elif 'rating' in col_name:
            return "random_decimal:1.0,5.0"
        elif 'minutes' in col_name or 'time' in col_name:
            return "random_int:5,240"
        elif any(x in col_name for x in ['count', 'total', 'number']):
            return "random_int:0,1000"
        # ID/Key columns - but check data_type first
        if col_name.endswith('_id') or col_name.endswith('_key'):
            # If numeric, generate numeric IDs not UUIDs
            if any(t in data_type for t in ('INT', 'NUMBER', 'NUMERIC', 'DECIMAL', 'BIGINT', 'SMALLINT')):
                return "random_int:1,1000"
            return "uuid"
        # Default for unknown
        if 'INT' in data_type:
            return "random_int:1,1000"
        elif 'DECIMAL' in data_type or 'FLOAT' in data_type:
            return "random_decimal:0.00,1000.00"
        elif 'VARCHAR' in data_type or 'TEXT' in data_type:
            return "random_string:10"
        return "null"

    def _classify_by_type(self, column: Column) -> tuple[ColumnClassification, str]:
        """Classify column by its data type (GENERIC with an inferred faker spec)."""
        faker_type = self._infer_faker_type(column)
        return ColumnClassification.GENERIC, faker_type

    def _schema_to_description(self, schema: Schema) -> str:
        """Convert schema to a text description for the AI prompt."""
        lines = []
        for table in schema.tables:
            table_type = "FACT" if table.is_fact_table else "DIMENSION"
            lines.append(f"\n{table.name} ({table_type}):")
            for col in table.columns:
                pk = " [PK]" if col.is_primary_key else ""
                fk = ""
                for fk_ref in table.foreign_keys:
                    if fk_ref.column_name.upper() == col.name.upper():
                        fk = f" [FK -> {fk_ref.references_table}]"
                        break
                lines.append(f"  - {col.name}: {col.data_type}{pk}{fk}")
        return '\n'.join(lines)

    def _normalize_classifications(
        self,
        result: dict,
        schema: Schema
    ) -> dict[str, dict[str, dict]]:
        """Normalize and validate AI classification results.

        Tolerates null/missing/non-dict entries and invalid classification
        labels from the model; every schema column ends up with a valid
        {classification, strategy} dict.
        """
        normalized = {}
        # Handle case where AI returned null/None
        if result is None:
            result = {}
        valid_classifications = {"SEARCH_REAL", "AI_GEN", "GENERIC"}
        for table in schema.tables:
            # Find matching table (case-insensitive)
            table_key = next(
                (key for key in result if key.upper() == table.name.upper()),
                None
            )
            normalized[table.name] = {}
            table_data = result.get(table_key, {}) if table_key else {}
            if not isinstance(table_data, dict):
                table_data = {}  # Handle AI returning null/garbage for a table
            for column in table.columns:
                col_key = next(
                    (key for key in table_data if key.upper() == column.name.upper()),
                    None
                )
                # OVERRIDE: Force certain columns to GENERIC regardless of AI
                # classification - these are columns where Faker does a better job.
                forced_generic = self._should_force_generic(column.name, column.data_type)
                if forced_generic:
                    normalized[table.name][column.name] = {
                        "classification": "GENERIC",
                        "strategy": forced_generic
                    }
                    continue
                col_data = table_data.get(col_key) if col_key else None
                if isinstance(col_data, dict):
                    classification = str(col_data.get("classification", "GENERIC")).upper()
                    if classification not in valid_classifications:
                        classification = "GENERIC"
                    normalized[table.name][column.name] = {
                        "classification": classification,
                        "strategy": col_data.get("strategy", "unknown")
                    }
                else:
                    # Default to GENERIC if not found or malformed
                    normalized[table.name][column.name] = {
                        "classification": "GENERIC",
                        "strategy": self._infer_faker_type(column)
                    }
        return normalized

    def _should_force_generic(self, col_name: str, data_type: str | None = None) -> str | None:
        """
        Check if column should be forced to GENERIC classification.

        Args:
            col_name: Column name as it appears in the schema.
            data_type: Optional SQL data type; refines boolean/numeric handling.

        Returns:
            The faker strategy string if the column must be GENERIC, else None.

        These are columns where Faker produces better results than AI generation.
        """
        col_lower = col_name.lower()
        semantic = infer_semantic_type(col_name, data_type)
        # Never force business categorical columns to generic;
        # these need domain-aware contextual generation.
        if is_business_categorical(semantic):
            return None
        # BOOLEAN COLUMNS - Check FIRST before phone/contact heuristics so
        # boolean names containing "mobile" don't become phone numbers.
        if data_type:
            dt_upper = data_type.upper()
            if 'BOOLEAN' in dt_upper or dt_upper in ('BOOL', 'BIT'):
                return "boolean:0.5"
        # BOOLEAN NAME PATTERNS
        if col_lower.startswith('is_') or col_lower.startswith('has_') or col_lower.endswith('_flag'):
            return "boolean:0.5"
        # Birth date columns must always produce realistic adult ages (18-95).
        is_birth_col = (
            'birthdate' in col_lower or
            'birth_date' in col_lower or
            'date_of_birth' in col_lower or
            col_lower == 'dob' or
            col_lower.endswith('_dob')
        )
        if is_birth_col:
            today_dt = datetime.now()
            oldest = _years_ago(today_dt, 95).strftime('%Y-%m-%d')
            youngest = _years_ago(today_dt, 18).strftime('%Y-%m-%d')
            return f"date_between:{oldest},{youngest}"
        # Check if data type is numeric
        is_numeric = False
        if data_type:
            dt_upper = data_type.upper()
            is_numeric = any(t in dt_upper for t in ('INT', 'NUMBER', 'NUMERIC', 'DECIMAL', 'BIGINT', 'SMALLINT'))
        # ID columns - should be IDs, not random words.
        # Match: order_id, orderid, customer_id, transaction_id, etc.
        # NOTE(review): the bare endswith('id') also matches words like
        # 'grid' or 'rapid' beyond the explicit exclusions below.
        if col_lower.endswith('_id') or col_lower.endswith('id'):
            # But not 'paid', 'said', 'valid', etc.
            if col_lower in ('paid', 'said', 'valid', 'invalid', 'void'):
                return None
            # If column is numeric, generate numeric IDs instead of UUIDs
            if is_numeric:
                return "random_int:1,1000"
            return "uuid"
        # Key columns - similar to IDs
        if col_lower.endswith('_key') or col_lower.endswith('key'):
            if col_lower in ('key',):  # standalone 'key' might be something else
                return None
            # If column is numeric, generate numeric keys instead of UUIDs
            if is_numeric:
                return "random_int:1,1000"
            return "uuid"
        # Email - Faker does this perfectly
        if 'email' in col_lower:
            return "email"
        # Phone numbers (but NOT is_mobile - that's boolean, checked above)
        if 'phone' in col_lower or 'mobile' in col_lower or 'tel' in col_lower:
            return "phone"
        # First name - exact matches
        if col_lower in ('first_name', 'firstname', 'fname', 'given_name', 'givenname'):
            return "first_name"
        # Last name - exact matches
        if col_lower in ('last_name', 'lastname', 'lname', 'surname', 'family_name', 'familyname'):
            return "last_name"
        # Full name - any column that's clearly a person's name
        if col_lower in ('full_name', 'fullname', 'name', 'user_name', 'username'):
            return "name"
        if col_lower in ('customer_name', 'account_name', 'client_name', 'company_name', 'organization_name'):
            return "company"
        # Person name columns that don't match exact list above
        # rep_name, manager_name, agent_name, employee_name, contact_name, etc.
        if col_lower.endswith('_name') or col_lower.endswith('name'):
            # Exclude non-person names (these are handled elsewhere as SEARCH_REAL or AI_GEN)
            non_person = ('product_name', 'productname', 'company_name', 'companyname',
                          'account_name', 'accountname', 'customer_name', 'customername',
                          'client_name', 'clientname', 'organization_name', 'organizationname',
                          'brand_name', 'brandname', 'store_name', 'storename',
                          'warehouse_name', 'warehousename', 'center_name', 'centername',
                          'campaign_name', 'campaignname', 'table_name', 'tablename',
                          'column_name', 'columnname', 'schema_name', 'schemaname',
                          'file_name', 'filename', 'host_name', 'hostname',
                          'category_name', 'categoryname', 'holiday_name', 'holidayname',
                          'branch_name', 'branchname', 'department_name', 'departmentname',
                          'region_name', 'regionname', 'city_name', 'cityname',
                          'month_name', 'monthname', 'quarter_name', 'quartername', 'day_name', 'dayname')
            if col_lower not in non_person:
                return "name"
        # Address - contains 'address' (but not email_address)
        if 'address' in col_lower and 'email' not in col_lower:
            return "address"
        # City - contains 'city' anywhere (center_city, shipping_city, etc.)
        if 'city' in col_lower or col_lower == 'town':
            return "city"
        # State/Province - contains 'state' or 'province'
        if 'state' in col_lower or 'province' in col_lower:
            return "state"
        # Zip/Postal - contains 'zip' or 'postal'
        if 'zip' in col_lower or 'postal' in col_lower:
            return "zipcode"
        # Country - contains 'country'
        if 'country' in col_lower:
            return "country"
        # URL/Website
        # Avoid false positives like "hourly" containing "url".
        if (
            re.search(r'(^|[_\-])url([_\-]|$)', col_lower)
            or 'website' in col_lower
        ):
            return "url"
        return None