# Source: demoprep/legitdata_project/legitdata/analyzer/column_classifier.py
# Commit 5ac32c1 (mikeboone): feat: March 2026 sprint - new vision merge,
# pipeline improvements, settings refactor
"""Column classifier for determining data sourcing strategy."""
import json
import logging
import re
from datetime import datetime, timedelta
from typing import Optional

from ..config import COLUMN_HINTS
from ..ddl.models import Schema, Table, Column, ColumnClassification
from ..domain import infer_semantic_type, is_business_categorical
from .context_builder import CompanyContext
def _years_ago(dt: datetime, years: int) -> datetime:
"""Return datetime shifted back by full years (leap-safe)."""
try:
return dt.replace(year=dt.year - years)
except ValueError:
return dt.replace(month=2, day=28, year=dt.year - years)
class ColumnClassifier:
"""Classifies columns to determine how their data should be sourced."""
def __init__(self, anthropic_client=None):
"""
Initialize classifier.
Args:
anthropic_client: Anthropic client for AI classification
"""
self.anthropic_client = anthropic_client
def classify_schema(
self,
schema: Schema,
context: CompanyContext,
use_case: str
) -> dict[str, dict[str, dict]]:
"""
Classify all columns in a schema.
Args:
schema: Parsed database schema
context: Company context
use_case: Analytics use case
Returns:
Dict mapping table_name -> column_name -> {classification, strategy}
"""
if self.anthropic_client:
return self._classify_with_ai(schema, context, use_case)
else:
return self._classify_with_heuristics(schema, context, use_case)
    def _classify_with_ai(
        self,
        schema: Schema,
        context: CompanyContext,
        use_case: str
    ) -> dict[str, dict[str, dict]]:
        """Use AI to classify columns.

        Sends the rendered schema plus company context to the Anthropic
        API and parses the JSON mapping the model returns. On ANY failure
        (API error, malformed JSON, unexpected response shape) it prints a
        warning and falls back to the heuristic classifier, so this method
        never raises.

        Args:
            schema: Parsed database schema.
            context: Company context used to ground the prompt.
            use_case: Analytics use case.

        Returns:
            Dict mapping table_name -> column_name -> {classification, strategy},
            normalized against the schema (see _normalize_classifications).
        """
        # Build schema description
        schema_desc = self._schema_to_description(schema)
        prompt = f"""Analyze this database schema and classify each column for realistic data generation.
Company Context:
{context.to_prompt()}
Use Case: {use_case}
Schema:
{schema_desc}
For each column, determine:
1. classification: How to source the data
- SEARCH_REAL: Use web search for real-world data (products, brands, companies, etc.)
- AI_GEN: Generate contextually appropriate data (regions, categories, descriptions)
- GENERIC: Use calculated/random data (IDs, dates, amounts, booleans)
2. strategy: Specific instructions
- For SEARCH_REAL: The search query to use (e.g., "consumer electronics brands")
- For AI_GEN: Description of what to generate (e.g., "US geographic regions for distribution")
- For GENERIC: The faker type or calculation (e.g., "date_between:2020-01-01,2024-12-31" or "random_int:1,1000")
Return as JSON:
{{
"TABLE_NAME": {{
"column_name": {{
"classification": "SEARCH_REAL|AI_GEN|GENERIC",
"strategy": "specific instructions"
}}
}}
}}
Important rules:
- Primary key and ID columns are always GENERIC
- Foreign key columns are always GENERIC (they reference other tables)
- Columns ending in _key, _id, _date, _at are typically GENERIC
- Product names, brand names, company names should be SEARCH_REAL
- Region, category, segment, tier, type should be AI_GEN
- Amounts, quantities, prices, ratings should be GENERIC
- Boolean flags (is_*, has_*) should be GENERIC
Return ONLY valid JSON, no other text."""
        try:
            response = self.anthropic_client.messages.create(
                model="claude-sonnet-4-20250514",
                max_tokens=4000,
                messages=[{"role": "user", "content": prompt}]
            )
            json_str = response.content[0].text.strip()
            # Handle markdown code blocks: the model sometimes wraps the
            # JSON in ``` fences despite the "ONLY valid JSON" instruction.
            if json_str.startswith('```'):
                json_str = re.sub(r'^```\w*\n?', '', json_str)
                json_str = re.sub(r'\n?```$', '', json_str)
            result = json.loads(json_str)
            # Normalize and validate against the actual schema so every
            # table/column gets an entry even if the model omitted some.
            return self._normalize_classifications(result, schema)
        except Exception as e:
            # Broad catch is deliberate: classification must never crash
            # the pipeline; heuristics are an acceptable degraded mode.
            print(f"Warning: AI classification failed: {e}")
            return self._classify_with_heuristics(schema, context, use_case)
def _classify_with_heuristics(
self,
schema: Schema,
context: CompanyContext,
use_case: str
) -> dict[str, dict[str, dict]]:
"""Classify columns using pattern-based heuristics."""
result = {}
for table in schema.tables:
result[table.name] = {}
fk_columns = {fk.column_name.upper() for fk in table.foreign_keys}
for column in table.columns:
classification, strategy = self._classify_column_heuristic(
column, table, fk_columns, context, use_case
)
result[table.name][column.name] = {
"classification": classification.value,
"strategy": strategy
}
return result
def _classify_column_heuristic(
self,
column: Column,
table: Table,
fk_columns: set[str],
context: CompanyContext,
use_case: str
) -> tuple[ColumnClassification, str]:
"""Classify a single column using heuristics."""
col_name = column.name.lower()
col_upper = column.name.upper()
semantic = infer_semantic_type(column.name, column.data_type, table.name, use_case)
# Primary keys and identity columns
if column.is_primary_key or column.is_identity:
return ColumnClassification.GENERIC, "identity"
# Foreign keys
if col_upper in fk_columns:
return ColumnClassification.GENERIC, "foreign_key"
# Quality-first: keep business categorical columns contextual.
if is_business_categorical(semantic):
prompt = self._generate_ai_prompt(col_name, table.name, context, use_case)
return ColumnClassification.AI_GEN, prompt
# Check patterns for SEARCH_REAL
for pattern in COLUMN_HINTS["search_real_patterns"]:
if re.search(pattern, col_name, re.IGNORECASE):
query = self._generate_search_query(col_name, context, use_case)
return ColumnClassification.SEARCH_REAL, query
# Check patterns for AI_GEN
for pattern in COLUMN_HINTS["ai_gen_patterns"]:
if re.search(pattern, col_name, re.IGNORECASE):
prompt = self._generate_ai_prompt(col_name, table.name, context, use_case)
return ColumnClassification.AI_GEN, prompt
# Check patterns for GENERIC
for pattern in COLUMN_HINTS["generic_patterns"]:
if re.search(pattern, col_name, re.IGNORECASE):
faker_type = self._infer_faker_type(column)
return ColumnClassification.GENERIC, faker_type
# Default based on data type
return self._classify_by_type(column)
def _generate_search_query(
self,
col_name: str,
context: CompanyContext,
use_case: str
) -> str:
"""Generate a search query for a SEARCH_REAL column."""
industry = context.industry.lower()
if 'product' in col_name.lower():
return f"{industry} products list"
elif 'brand' in col_name.lower():
return f"{industry} brands companies"
elif 'seller' in col_name.lower() or 'vendor' in col_name.lower():
return f"{industry} vendors suppliers companies"
elif 'company' in col_name.lower() or 'customer' in col_name.lower():
return f"{industry} companies businesses"
else:
return f"{industry} {col_name.replace('_', ' ')}"
def _generate_ai_prompt(
self,
col_name: str,
table_name: str,
context: CompanyContext,
use_case: str
) -> str:
"""Generate a prompt for an AI_GEN column."""
col_clean = col_name.replace('_', ' ').title()
return f"Generate realistic {col_clean} values for {context.company_name} ({context.industry}). Use case: {use_case}."
def _infer_faker_type(self, column: Column) -> str:
"""Infer the faker type from column name and type."""
col_name = column.name.lower()
data_type = column.data_type.upper()
# BOOLEAN COLUMNS - Check FIRST before phone/contact heuristics
# so names like MobileTransaction BOOLEAN don't get treated as
# phone numbers just because they contain "mobile".
if 'BOOLEAN' in data_type or data_type in ('BOOL', 'BIT'):
return "boolean:0.8"
# BOOLEAN NAME PATTERNS - still keep these ahead of contact fields.
if col_name.startswith('is_') or col_name.startswith('has_') or col_name.endswith('_flag'):
return "boolean:0.5"
# Email columns
if 'email' in col_name:
return "email"
# Phone columns (but NOT is_mobile - that's boolean, checked above)
if 'phone' in col_name or 'mobile' in col_name or 'tel' in col_name:
return "phone"
# Identifier-like short codes should not be treated as monetary values.
if (
'tax_id' in col_name
or 'taxid' in col_name
or 'last4' in col_name
or col_name.endswith('_id_last4')
):
if any(t in data_type for t in ('CHAR', 'TEXT', 'STRING', 'VARCHAR')):
return "random_string:4"
return "random_int:1000,9999"
# Name columns
if col_name in ('first_name', 'firstname', 'fname', 'given_name'):
return "first_name"
if col_name in ('last_name', 'lastname', 'lname', 'surname'):
return "last_name"
if col_name in ('full_name', 'fullname', 'name', 'user_name'):
return "name"
if col_name in ('customer_name', 'account_name', 'client_name', 'company_name', 'organization_name'):
return "company"
# Address columns
if 'address' in col_name and 'email' not in col_name:
return "address"
if col_name in ('city', 'town'):
return "city"
if col_name in ('state', 'province'):
return "state"
if col_name in ('country', 'nation'):
return "country"
if 'zip' in col_name or 'postal' in col_name:
return "zipcode"
# Birth date columns - must be evaluated before generic date checks
# to avoid impossible ages from generic recent-date generation.
today_dt = datetime.now()
today = today_dt.strftime('%Y-%m-%d')
is_birth_col = (
'birthdate' in col_name or
'birth_date' in col_name or
'date_of_birth' in col_name or
col_name == 'dob' or
col_name.endswith('_dob')
)
if is_birth_col:
normalized_today = today_dt.replace(hour=0, minute=0, second=0, microsecond=0)
oldest = _years_ago(normalized_today, 95).strftime('%Y-%m-%d')
youngest = _years_ago(normalized_today, 18).strftime('%Y-%m-%d')
return f"date_between:{oldest},{youngest}"
# Date columns - use current date for end range.
# Guard against financial names like *_AMOUNT_TO_DATE on numeric fields.
financial_to_date_name = (
('to_date' in col_name or col_name.endswith('todate'))
and any(
token in col_name
for token in (
'amount',
'balance',
'cost',
'price',
'revenue',
'sales',
'fee',
'tax',
'total',
'recovery',
'payment',
)
)
)
is_name_date = ('date' in col_name) and not financial_to_date_name
if is_name_date or data_type in ('DATE', 'TIMESTAMP', 'DATETIME'):
if 'created' in col_name or 'registration' in col_name:
return f"date_between:2020-01-01,{today}"
elif 'launch' in col_name or 'open' in col_name:
return f"date_between:2015-01-01,{today}"
else:
return f"date_between:2024-01-01,{today}"
# Boolean columns
if 'BOOLEAN' in data_type or col_name.startswith(('is_', 'has_')):
return "boolean:0.8" # 80% true by default
# Numeric columns
if 'quantity' in col_name:
return "random_int:1,100"
elif 'discount' in col_name:
return "random_decimal:0.00,200.00"
elif 'shipping' in col_name:
return "random_decimal:2.00,150.00"
elif 'tax' in col_name:
return "random_decimal:1.00,250.00"
elif 'revenue' in col_name:
return "random_decimal:10.00,10000.00"
elif 'amount' in col_name:
return "random_decimal:5.00,5000.00"
elif 'cost' in col_name or 'fee' in col_name or 'price' in col_name:
return "random_decimal:1.00,500.00"
elif 'rating' in col_name:
return "random_decimal:1.0,5.0"
elif 'minutes' in col_name or 'time' in col_name:
return "random_int:5,240"
elif any(x in col_name for x in ['count', 'total', 'number']):
return "random_int:0,1000"
# ID/Key columns - but check data_type first
if col_name.endswith('_id') or col_name.endswith('_key'):
# If numeric, generate numeric IDs not UUIDs
if any(t in data_type for t in ('INT', 'NUMBER', 'NUMERIC', 'DECIMAL', 'BIGINT', 'SMALLINT')):
return "random_int:1,1000"
return "uuid"
# Default for unknown
if 'INT' in data_type:
return "random_int:1,1000"
elif 'DECIMAL' in data_type or 'FLOAT' in data_type:
return "random_decimal:0.00,1000.00"
elif 'VARCHAR' in data_type or 'TEXT' in data_type:
return "random_string:10"
return "null"
def _classify_by_type(self, column: Column) -> tuple[ColumnClassification, str]:
"""Classify column by its data type."""
faker_type = self._infer_faker_type(column)
return ColumnClassification.GENERIC, faker_type
def _schema_to_description(self, schema: Schema) -> str:
"""Convert schema to text description for AI."""
lines = []
for table in schema.tables:
table_type = "FACT" if table.is_fact_table else "DIMENSION"
lines.append(f"\n{table.name} ({table_type}):")
for col in table.columns:
pk = " [PK]" if col.is_primary_key else ""
fk = ""
for fk_ref in table.foreign_keys:
if fk_ref.column_name.upper() == col.name.upper():
fk = f" [FK -> {fk_ref.references_table}]"
break
lines.append(f" - {col.name}: {col.data_type}{pk}{fk}")
return '\n'.join(lines)
def _normalize_classifications(
self,
result: dict,
schema: Schema
) -> dict[str, dict[str, dict]]:
"""Normalize and validate AI classification results."""
normalized = {}
# Handle case where AI returned null/None
if result is None:
result = {}
for table in schema.tables:
table_key = None
# Find matching table (case-insensitive)
for key in result.keys():
if key.upper() == table.name.upper():
table_key = key
break
normalized[table.name] = {}
table_data = result.get(table_key, {}) if table_key else {}
if table_data is None:
table_data = {} # Handle AI returning null for a table
for column in table.columns:
col_key = None
for key in table_data.keys():
if key.upper() == column.name.upper():
col_key = key
break
# OVERRIDE: Force certain columns to GENERIC regardless of AI classification
# These are columns where Faker does a better job than AI
forced_generic = self._should_force_generic(column.name, column.data_type)
if forced_generic:
normalized[table.name][column.name] = {
"classification": "GENERIC",
"strategy": forced_generic
}
elif col_key and col_key in table_data:
col_data = table_data[col_key]
normalized[table.name][column.name] = {
"classification": col_data.get("classification", "GENERIC"),
"strategy": col_data.get("strategy", "unknown")
}
else:
# Default to GENERIC if not found
normalized[table.name][column.name] = {
"classification": "GENERIC",
"strategy": self._infer_faker_type(column)
}
return normalized
def _should_force_generic(self, col_name: str, data_type: str = None) -> str | None:
"""
Check if column should be forced to GENERIC classification.
Returns faker strategy if yes, None if no.
These are columns where Faker produces better results than AI generation.
"""
col_lower = col_name.lower()
semantic = infer_semantic_type(col_name, data_type)
# Never force business categorical columns to generic;
# these need domain-aware contextual generation.
if is_business_categorical(semantic):
return None
# DEBUG: Log all checks
print(f" [FORCE_CHECK] Column: '{col_name}' (lower: '{col_lower}')")
# BOOLEAN COLUMNS - Check FIRST before phone/contact heuristics so
# boolean names containing "mobile" don't become phone numbers.
if data_type:
dt_upper = data_type.upper()
if 'BOOLEAN' in dt_upper or dt_upper in ('BOOL', 'BIT'):
print(f" β†’ FORCED to 'boolean:0.5' (boolean data type)")
return "boolean:0.5"
# BOOLEAN NAME PATTERNS
if col_lower.startswith('is_') or col_lower.startswith('has_') or col_lower.endswith('_flag'):
print(f" β†’ FORCED to 'boolean:0.5' (boolean pattern)")
return "boolean:0.5"
# Birth date columns must always produce realistic adult ages.
is_birth_col = (
'birthdate' in col_lower or
'birth_date' in col_lower or
'date_of_birth' in col_lower or
col_lower == 'dob' or
col_lower.endswith('_dob')
)
if is_birth_col:
today_dt = datetime.now()
oldest = _years_ago(today_dt, 95).strftime('%Y-%m-%d')
youngest = _years_ago(today_dt, 18).strftime('%Y-%m-%d')
return f"date_between:{oldest},{youngest}"
# Check if data type is numeric
is_numeric = False
if data_type:
dt_upper = data_type.upper()
is_numeric = any(t in dt_upper for t in ('INT', 'NUMBER', 'NUMERIC', 'DECIMAL', 'BIGINT', 'SMALLINT'))
# ID columns - should be IDs, not random words
# Match: order_id, orderid, customer_id, transaction_id, etc.
if col_lower.endswith('_id') or col_lower.endswith('id'):
# But not 'paid', 'said', 'valid', etc.
if col_lower in ('paid', 'said', 'valid', 'invalid', 'void'):
return None
# If column is numeric, generate numeric IDs instead of UUIDs
if is_numeric:
print(f" β†’ FORCED to 'random_int:1,1000' (numeric ID column)")
return "random_int:1,1000"
print(f" β†’ FORCED to 'uuid' (ID column)")
return "uuid"
# Key columns - similar to IDs
if col_lower.endswith('_key') or col_lower.endswith('key'):
if col_lower in ('key',): # standalone 'key' might be something else
return None
# If column is numeric, generate numeric keys instead of UUIDs
if is_numeric:
print(f" β†’ FORCED to 'random_int:1,1000' (numeric key column)")
return "random_int:1,1000"
return "uuid"
# Email - Faker does this perfectly
if 'email' in col_lower:
print(f" β†’ FORCED to 'email'")
return "email"
# Phone numbers (but NOT is_mobile - that's boolean, checked above)
if 'phone' in col_lower or 'mobile' in col_lower or 'tel' in col_lower:
return "phone"
# First name - exact matches
if col_lower in ('first_name', 'firstname', 'fname', 'given_name', 'givenname'):
print(f" β†’ FORCED to 'first_name'")
return "first_name"
# Last name - exact matches
if col_lower in ('last_name', 'lastname', 'lname', 'surname', 'family_name', 'familyname'):
print(f" β†’ FORCED to 'last_name'")
return "last_name"
# Full name - any column that's clearly a person's name
if col_lower in ('full_name', 'fullname', 'name', 'user_name', 'username'):
return "name"
if col_lower in ('customer_name', 'account_name', 'client_name', 'company_name', 'organization_name'):
return "company"
# Person name columns that don't match exact list above
# rep_name, manager_name, agent_name, employee_name, contact_name, etc.
if col_lower.endswith('_name') or col_lower.endswith('name'):
# Exclude non-person names (these are handled elsewhere as SEARCH_REAL or AI_GEN)
non_person = ('product_name', 'productname', 'company_name', 'companyname',
'account_name', 'accountname', 'customer_name', 'customername',
'client_name', 'clientname', 'organization_name', 'organizationname',
'brand_name', 'brandname', 'store_name', 'storename',
'warehouse_name', 'warehousename', 'center_name', 'centername',
'campaign_name', 'campaignname', 'table_name', 'tablename',
'column_name', 'columnname', 'schema_name', 'schemaname',
'file_name', 'filename', 'host_name', 'hostname',
'category_name', 'categoryname', 'holiday_name', 'holidayname',
'branch_name', 'branchname', 'department_name', 'departmentname',
'region_name', 'regionname', 'city_name', 'cityname',
'month_name', 'monthname', 'quarter_name', 'quartername', 'day_name', 'dayname')
if col_lower not in non_person:
print(f" β†’ FORCED to 'name' (person name pattern: {col_lower})")
return "name"
# Address - contains 'address' (but not email_address)
if 'address' in col_lower and 'email' not in col_lower:
return "address"
# City - contains 'city' anywhere (center_city, shipping_city, etc.)
if 'city' in col_lower or col_lower == 'town':
print(f" β†’ FORCED to 'city'")
return "city"
# State/Province - contains 'state' or 'province'
if 'state' in col_lower or 'province' in col_lower:
print(f" β†’ FORCED to 'state'")
return "state"
# Zip/Postal - contains 'zip' or 'postal'
if 'zip' in col_lower or 'postal' in col_lower:
return "zipcode"
# Country - contains 'country'
if 'country' in col_lower:
print(f" β†’ FORCED to 'country'")
return "country"
# URL/Website
# Avoid false positives like "hourly" containing "url".
if (
re.search(r'(^|[_\-])url([_\-]|$)', col_lower)
or 'website' in col_lower
):
return "url"
return None