# NOTE: removed non-Python page-scrape residue that preceded this module
# ("Spaces:" / "Running" / "Running") — it was not part of the source.
| """ | |
| LegitData Bridge - Integration layer between demo_prep and legitdata. | |
| This module provides a clean interface for demo_prep to use legitdata | |
| for realistic data generation instead of faker-based scripts. | |
| Usage: | |
| from legitdata_bridge import populate_demo_data, preview_demo_data | |
| # Preview what will be generated | |
| preview = preview_demo_data(ddl, url, use_case) | |
| # Actually populate the database | |
| success, message, results = populate_demo_data( | |
| ddl_content=ddl, | |
| company_url=url, | |
| use_case=use_case, | |
| schema_name=schema, | |
| size="medium" | |
| ) | |
| """ | |
import os
import sys
import signal
from dataclasses import dataclass

# Restore the OS default disposition for SIGPIPE so a closed stdout during
# long-running operations ends the process quietly instead of surfacing
# BrokenPipeError tracebacks mid-generation.
# NOTE(review): SIG_DFL is the *default* action (silent termination), not
# SIG_IGN — confirm silent termination is the intended trade-off here.
try:
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)
except (AttributeError, ValueError):
    pass  # SIGPIPE not available on Windows

# Make the vendored legitdata_project directory importable as `legitdata`.
_legitdata_path = os.path.join(os.path.dirname(__file__), 'legitdata_project')
if _legitdata_path not in sys.path:
    sys.path.insert(0, _legitdata_path)

from typing import Optional, Callable, Tuple, Dict, Any
from dotenv import load_dotenv
from llm_config import (
    build_openai_chat_token_kwargs,
    is_openai_model_name,
    resolve_model_name,
)
from llm_client_factory import create_openai_client

# Load .env so the OpenAI client factory and Snowflake auth can read credentials.
load_dotenv()
@dataclass
class _CompatTextBlock:
    """Mimics one Anthropic response content block: a single ``text`` field.

    Fix: the original bare class carried only a class-level annotation, so
    ``_CompatTextBlock(text=...)`` (as called by ``_CompatResponse``) raised
    ``TypeError: object() takes no keyword arguments``. ``@dataclass``
    generates the required ``__init__``; ``dataclass`` is already imported
    at the top of this module.
    """
    text: str
class _CompatResponse:
    """Anthropic-shaped response wrapper exposing ``.content[0].text``."""

    def __init__(self, text: str):
        # Normalize falsy text (e.g. None) to "" so callers can always
        # read .content[0].text safely.
        block = _CompatTextBlock(text=text or "")
        self.content = [block]
class _OpenAIAnthropicCompatMessages:
    """Compatibility shim to satisfy LegitData's anthropic_client.messages.create() calls."""

    def __init__(self, openai_client, configured_model: str):
        # Keep the raw configured model name; it is resolved per-call in create().
        self._openai_client = openai_client
        self._configured_model = configured_model

    def create(self, model=None, max_tokens=1000, messages=None, **kwargs):
        """Translate an Anthropic-style create() call into an OpenAI chat completion."""
        # LegitData currently hardcodes Claude model IDs internally.
        # We intentionally override those with the selected OpenAI model.
        target_model = resolve_model_name(model)
        if not is_openai_model_name(target_model):
            target_model = resolve_model_name(self._configured_model)
        if not is_openai_model_name(target_model):
            raise ValueError(
                f"LegitData requires an OpenAI GPT/Codex model. Received: '{target_model}'."
            )

        request = {
            "model": target_model,
            "messages": messages or [],
        }
        # Only forward temperature when the caller actually supplied one.
        if kwargs.get("temperature") is not None:
            request["temperature"] = kwargs.get("temperature")
        request.update(build_openai_chat_token_kwargs(target_model, max_tokens))

        completion = self._openai_client.chat.completions.create(**request)
        reply_text = completion.choices[0].message.content if completion.choices else ""
        return _CompatResponse(reply_text)
class OpenAICompatClient:
    """Minimal client with Anthropic-like shape used by LegitData internals."""

    def __init__(self, configured_model: str):
        # Keep a resolved copy for our own record; the messages shim receives
        # the raw name and re-resolves it on every create() call.
        self._configured_model = resolve_model_name(configured_model)
        self._openai_client = create_openai_client(timeout=60, max_retries=3)
        self.messages = _OpenAIAnthropicCompatMessages(
            self._openai_client, configured_model
        )
def get_legitdata_llm_client(llm_model: str):
    """Create the LLM client for LegitData using selected app model.

    Raises:
        ValueError: when no model is configured, or the configured model
            is not an OpenAI GPT/Codex model.
    """
    model_name = resolve_model_name(llm_model)
    if not model_name:
        raise ValueError("LegitData requires llm_model from settings, but none was provided.")
    # Only the OpenAI path is supported; anything else is rejected up front.
    if is_openai_model_name(model_name):
        return OpenAICompatClient(model_name)
    raise ValueError(
        f"LegitData data generation only supports OpenAI GPT/Codex models right now. Got: '{model_name}'."
    )
def get_snowflake_connection_params_safe() -> dict:
    """
    Get Snowflake connection params from the app auth module.

    The import is deferred so this module can be imported even when the
    Snowflake auth module is unavailable at import time.
    """
    from snowflake_auth import get_snowflake_connection_params as _fetch_params
    return _fetch_params()
class KeyPairSnowflakeWriter:
    """
    Snowflake writer that uses the app's key-pair authentication.
    Drop-in replacement for legitdata's SnowflakeWriter.

    Usable as a context manager: connects on __enter__, disconnects on
    __exit__. All writes are committed explicitly per table.
    """

    def __init__(self, schema_name: str):
        # Target schema; the connection is established lazily in connect().
        self.schema_name = schema_name
        self.connection = None
        self.cursor = None

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always disconnect; returning False re-raises any pending exception.
        self.disconnect()
        return False

    def connect(self) -> None:
        """Establish Snowflake connection using app's key-pair auth."""
        import snowflake.connector
        from snowflake_auth import get_snowflake_connection_params
        conn_params = get_snowflake_connection_params()
        database = conn_params.get('database', 'DEMOBUILD')
        # Add keep-alive to prevent session timeout during long operations
        # NOTE(review): no keep-alive option is actually passed below —
        # confirm whether client_session_keep_alive should be set here.
        self.connection = snowflake.connector.connect(**conn_params)
        self.cursor = self.connection.cursor()
        # Explicitly set database and schema context (quote schema name for special chars)
        self.cursor.execute(f"USE DATABASE {database}")
        self.cursor.execute(f'USE SCHEMA "{self.schema_name}"')
        print(f"Connected to Snowflake: {database}.{self.schema_name}")

    def disconnect(self) -> None:
        """Close Snowflake connection."""
        if self.cursor:
            self.cursor.close()
            self.cursor = None
        if self.connection:
            self.connection.close()
            self.connection = None
        print("Disconnected from Snowflake")

    def insert_rows(self, table_name: str, columns: list, rows: list, batch_size: int = 1000) -> int:
        """Insert rows into a Snowflake table.

        Values are coerced to fit the table's column metadata (VARCHAR
        truncation, numeric clamping/rounding, NaN/Inf -> NULL). Batches
        that fail with executemany() fall back to row-by-row inserts so a
        single bad row doesn't sink the whole batch.

        Returns the number of rows successfully inserted; raises
        RuntimeError if every row failed.
        """
        from decimal import Decimal
        from datetime import datetime, date
        if not rows:
            return 0
        if not self.cursor:
            raise RuntimeError("Not connected to database")
        # Get column metadata for this table to know constraints
        column_info = self._get_column_info(table_name)
        # Filter out identity columns (let Snowflake auto-generate)
        # NOTE(review): only the first 10 rows are sampled here; a column
        # whose values begin after row 10 would be dropped entirely — confirm
        # this heuristic is acceptable.
        filtered_columns = []
        filtered_indices = []
        for i, col in enumerate(columns):
            has_values = any(row[i] is not None for row in rows[:10])
            if has_values:
                filtered_columns.append(col)
                filtered_indices.append(i)
        if not filtered_columns:
            print(f"Warning: No columns with values to insert for {table_name}")
            return 0
        col_list = ', '.join(filtered_columns)
        placeholders = ', '.join(['%s'] * len(filtered_columns))
        sql = f"INSERT INTO {table_name} ({col_list}) VALUES ({placeholders})"

        def convert_value(value, col_name):
            # Coerce one Python value into a form Snowflake accepts for this
            # column, using the metadata captured in column_info above.
            import math
            if value is None:
                return None
            if isinstance(value, bool):
                # bool before the numeric paths: insert as 0/1.
                return int(value)
            if isinstance(value, Decimal):
                value = float(value)
            # NaN and Inf are not valid SQL values — replace with NULL
            if isinstance(value, float) and (math.isnan(value) or math.isinf(value)):
                return None
            if isinstance(value, datetime):
                return value.strftime('%Y-%m-%d %H:%M:%S')
            if isinstance(value, date):
                return value.strftime('%Y-%m-%d')
            # Apply constraints based on column metadata
            col_upper = col_name.upper()
            if col_upper in column_info:
                info = column_info[col_upper]
                data_type = info.get('type', '')
                # Truncate strings to fit VARCHAR length
                if 'VARCHAR' in data_type or 'TEXT' in data_type:
                    max_len = info.get('length', 255)
                    if isinstance(value, str) and len(value) > max_len:
                        value = value[:max_len]
                # Clamp numbers to fit DECIMAL precision
                elif 'NUMBER' in data_type or 'DECIMAL' in data_type or 'NUMERIC' in data_type or 'INT' in data_type or 'FLOAT' in data_type:
                    # If a string landed in a numeric column (AI misclassification), coerce or null it
                    if isinstance(value, str):
                        try:
                            value = float(value)
                        except (ValueError, TypeError):
                            return None  # Can't coerce — use NULL rather than crash
                    precision = info.get('precision', 38)
                    scale = info.get('scale', 0)
                    if isinstance(value, (int, float)):
                        # Max value for given precision/scale
                        max_val = 10 ** (precision - scale) - (10 ** -scale)
                        min_val = -max_val
                        value = max(min_val, min(max_val, value))
                        # Round to scale
                        if scale > 0:
                            value = round(value, scale)
            return value

        total_inserted = 0
        first_error = None
        for batch_start in range(0, len(rows), batch_size):
            batch_end = min(batch_start + batch_size, len(rows))
            batch = rows[batch_start:batch_end]
            filtered_batch = []
            for row in batch:
                filtered_row = [convert_value(row[i], filtered_columns[j]) for j, i in enumerate(filtered_indices)]
                filtered_batch.append(filtered_row)
            try:
                self.cursor.executemany(sql, filtered_batch)
                total_inserted += len(filtered_batch)
            except Exception as e:
                # Batch failed: remember the first error for diagnostics, then
                # retry row-by-row so good rows still land.
                if first_error is None:
                    first_error = str(e)
                print(f"Error inserting batch into {table_name}: {e}")
                for row in filtered_batch:
                    try:
                        self.cursor.execute(sql, row)
                        total_inserted += 1
                    except Exception as row_error:
                        col_val_pairs = list(zip(filtered_columns, row))
                        print(f" Failed row ({type(row_error).__name__}: {row_error})")
                        print(f" Values: {col_val_pairs[:6]}")
        self.connection.commit()
        print(f"Inserted {total_inserted} rows into {table_name}")
        if total_inserted == 0 and len(rows) > 0:
            raise RuntimeError(
                f"All {len(rows)} rows failed to insert into {table_name}. "
                f"First error: {first_error}"
            )
        return total_inserted

    def _get_column_info(self, table_name: str) -> dict:
        """Get column metadata (type, length, precision) for a table.

        Returns a dict keyed by UPPERCASE column name; values hold 'type'
        plus 'length' (VARCHAR) or 'precision'/'scale' (NUMBER/DECIMAL).
        Returns {} on any failure so insertion can proceed without limits.
        """
        if not self.cursor:
            return {}
        try:
            self.cursor.execute(f"DESCRIBE TABLE {table_name}")
            results = self.cursor.fetchall()
            column_info = {}
            for row in results:
                col_name = row[0].upper()
                data_type = row[1].upper()
                info = {'type': data_type}
                # Parse VARCHAR(n)
                if 'VARCHAR' in data_type:
                    import re
                    match = re.search(r'\((\d+)\)', data_type)
                    if match:
                        info['length'] = int(match.group(1))
                    else:
                        info['length'] = 255  # Default
                # Parse NUMBER(p,s) or DECIMAL(p,s)
                elif 'NUMBER' in data_type or 'DECIMAL' in data_type or 'NUMERIC' in data_type:
                    import re
                    match = re.search(r'\((\d+),?\s*(\d*)\)', data_type)
                    if match:
                        info['precision'] = int(match.group(1))
                        info['scale'] = int(match.group(2)) if match.group(2) else 0
                    else:
                        info['precision'] = 38
                        info['scale'] = 0
                column_info[col_name] = info
            return column_info
        except Exception as e:
            # Best-effort: metadata is an optimization, not a requirement.
            print(f"Warning: Could not get column info for {table_name}: {e}")
            return {}

    def truncate_table(self, table_name: str) -> None:
        """Truncate a Snowflake table."""
        if not self.cursor:
            raise RuntimeError("Not connected to database")
        self.cursor.execute(f"TRUNCATE TABLE {table_name}")
        self.connection.commit()
        print(f"Truncated table {table_name}")
| class DemoLegitGenerator: | |
| """ | |
| LegitGenerator wrapper that uses the app's key-pair auth for writing. | |
| Wraps the real LegitGenerator but overrides database writing. | |
| """ | |
| def __init__(self, schema_name: str, base_generator): | |
| self._schema_name = schema_name | |
| self._gen = base_generator | |
| self._comment_choices: dict = {} # col_name_upper -> "choice:val1,val2,..." | |
| def _extract_comment_choices(self, ddl_content: str) -> dict: | |
| """Extract 'Values: val1 | val2 | ...' from COMMENT annotations in DDL. | |
| LegitData strips these before parsing, so we capture them here first | |
| and inject them as generation strategies after _classify_columns() runs. | |
| """ | |
| import re | |
| choices = {} | |
| # Matches: col_name DATATYPE... COMMENT 'Values: val1 | val2 | val3' | |
| pattern = r'(\w+)\s+\w+[^,\n]*COMMENT\s+[\'"]Values:\s*([^\'\"]+)[\'"]' | |
| for m in re.finditer(pattern, ddl_content, re.IGNORECASE): | |
| col_name = m.group(1).upper() | |
| values = [v.strip() for v in m.group(2).split('|') if v.strip()] | |
| if values: | |
| choices[col_name] = "choice:" + ",".join(values) | |
| print(f" [DemoPrep] COMMENT values found for {col_name}: {values[:4]}{'...' if len(values) > 4 else ''}") | |
| return choices | |
| def load_ddl(self, ddl_content: str): | |
| """Extract COMMENT value hints before passing to base generator (which strips them).""" | |
| self._comment_choices = self._extract_comment_choices(ddl_content) | |
| return self._gen.load_ddl(ddl_content) | |
| def generate(self, size: str = "medium", row_counts=None, truncate_first: bool = True): | |
| """Generate data and write using our key-pair auth writer.""" | |
| from legitdata.config import GenerationConfig | |
| if not self._gen.schema: | |
| raise RuntimeError("No schema loaded. Call load_ddl() first.") | |
| config = GenerationConfig( | |
| url=self._gen.url, | |
| use_case=self._gen.use_case, | |
| size=size, | |
| row_counts=row_counts | |
| ) | |
| print("\n=== Step 1: Building Company Context ===") | |
| self._gen._build_context() | |
| print("\n=== Step 2: Classifying Columns ===") | |
| self._gen._classify_columns() | |
| # Override generation strategy for COMMENT-annotated columns. | |
| # Must run AFTER _classify_columns() so our values take precedence. | |
| # Only inject choice: strategy on string/varchar columns — never numeric. | |
| _NUMERIC_TYPES = {'int', 'integer', 'bigint', 'smallint', 'tinyint', | |
| 'number', 'numeric', 'decimal', 'float', 'double', | |
| 'real', 'money', 'byteint'} | |
| if self._comment_choices: | |
| for table in self._gen.schema.tables: | |
| for col in table.columns: | |
| strategy = self._comment_choices.get(col.name.upper()) | |
| if strategy: | |
| col_type = (col.data_type or '').lower().split('(')[0].strip() | |
| if col_type in _NUMERIC_TYPES: | |
| print(f" [DemoPrep] Skipping COMMENT choices for {table.name}.{col.name} — numeric column ({col.data_type})") | |
| continue | |
| col.generation_strategy = strategy | |
| print(f" [DemoPrep] Injected COMMENT choices into {table.name}.{col.name}") | |
| print("\n=== Step 3: Generating Data ===") | |
| generated_data = {} | |
| for table in self._gen.schema.get_dependency_order(): | |
| num_rows = config.get_table_row_count(table.name, table.is_fact_table) | |
| print(f"\nGenerating {num_rows} rows for {table.name}...") | |
| rows = self._gen._generate_table_data(table, num_rows) | |
| generated_data[table.name] = rows | |
| # Register PK values for FK references | |
| self._gen._register_pk_values(table, rows) | |
| print("\n=== Step 4: Writing to Database ===") | |
| results = self._write_to_database(generated_data, truncate_first) | |
| print("\n=== Generation Complete ===") | |
| total = sum(results.values()) | |
| print(f"Total rows inserted: {total}") | |
| return results | |
| def _write_to_database(self, data: dict, truncate_first: bool) -> dict: | |
| """Write using KeyPairSnowflakeWriter (app's auth).""" | |
| results = {} | |
| writer = KeyPairSnowflakeWriter(self._schema_name) | |
| with writer: | |
| for table in self._gen.schema.get_dependency_order(): | |
| rows = data.get(table.name, []) | |
| if not rows: | |
| continue | |
| columns = [col.name for col in table.columns] | |
| list_rows = [] | |
| for row in rows: | |
| list_rows.append([row.get(col) for col in columns]) | |
| if truncate_first: | |
| writer.truncate_table(table.name) | |
| count = writer.insert_rows(table.name, columns, list_rows) | |
| results[table.name] = count | |
| return results | |
| def schema(self): | |
| return self._gen.schema | |
def get_legitdata_generator(
    company_url: str,
    use_case: str,
    schema_name: str,
    llm_model: str,
    dry_run: bool = False,
    cache_enabled: bool = False  # Disabled during development
):
    """
    Factory function to create a configured LegitGenerator.

    Handles all the setup that demo_prep needs:
    - Snowflake connection from existing auth
    - LLM client derived from selected model in settings
    - Deterministic writer integration

    Args:
        company_url: Company website URL for context
        use_case: Use case description (e.g., "Retail Analytics")
        schema_name: Target Snowflake schema
        llm_model: Model name from app settings (OpenAI GPT/Codex only)
        dry_run: If True, don't actually write to database
        cache_enabled: If True, cache AI responses (disabled by default during dev)

    Returns:
        Configured LegitGenerator instance (wrapped in DemoLegitGenerator)
    """
    from legitdata import LegitGenerator

    # Resolve Snowflake parameters through the app's existing auth layer.
    conn_params = get_snowflake_connection_params_safe()

    # Assemble a snowflake:// URL for LegitGenerator's constructor. It is a
    # placeholder only — all real writes go through KeyPairSnowflakeWriter.
    # NOTE(review): this embeds the password in the URL; confirm the base
    # generator never logs the connection string.
    user = conn_params.get('user', '')
    password = conn_params.get('password', '')
    account = conn_params.get('account', '')
    database = conn_params.get('database', '')
    warehouse = conn_params.get('warehouse', '')
    connection_string = (
        f"snowflake://{user}:{password}@{account}/{database}/"
        f"{schema_name}?warehouse={warehouse}"
    )

    # LLM client derived from the selected model (OpenAI GPT path only).
    client = get_legitdata_llm_client(llm_model)

    base_gen = LegitGenerator(
        url=company_url,
        use_case=use_case,
        connection_string=connection_string,
        anthropic_client=client,
        web_search_fn=None,
        dry_run=True,  # Always dry_run for base - we handle writing ourselves
        cache_enabled=cache_enabled,
        cache_dir=".legitdata_cache"
    )

    # Wrap so database writes use the app's key-pair authentication.
    return DemoLegitGenerator(schema_name=schema_name, base_generator=base_gen)
def populate_demo_data(
    ddl_content: str,
    company_url: str,
    use_case: str,
    schema_name: str,
    llm_model: str,
    size: str = "medium",
    progress_callback: Optional[Callable[[str], None]] = None,
    truncate_first: bool = True,
    session_logger=None,
) -> Tuple[bool, str, Dict[str, int]]:
    """
    Main entry point for demo data population using LegitData.

    Replaces the old execute_population_script() function with
    direct AI-powered data generation. Makes up to two generation
    attempts and fails if any fact table is still empty afterwards.

    Args:
        ddl_content: SQL DDL from schema generation phase
        company_url: Company website URL
        use_case: Use case description
        schema_name: Target Snowflake schema name
        llm_model: Model name from app settings (OpenAI GPT/Codex only)
        size: Size preset - "small", "medium", "large", or "xl"
            - small: 100 fact rows, 20 dim rows
            - medium: 1,000 fact rows, 100 dim rows
            - large: 10,000 fact rows, 500 dim rows
            - xl: 100,000 fact rows, 500 dim rows
        progress_callback: Optional callback for progress updates
        truncate_first: If True, truncate tables before inserting
        session_logger: Optional logger exposing log_start/log_end/log_verbose

    Returns:
        Tuple of (success: bool, message: str, results: dict)
        - success: True if population completed
        - message: Human-readable status message
        - results: Dict mapping table names to row counts

    Example:
        success, message, results = populate_demo_data(
            ddl_content=ddl,
            company_url="https://amazon.com",
            use_case="Retail Analytics",
            schema_name="DEMO_AMZ_123",
            llm_model="gpt-4o",
            size="medium"
        )
        if success:
            print(f"Populated {sum(results.values())} rows")
        else:
            print(f"Failed: {message}")
    """
    _slog = session_logger
    import time as _time  # NOTE(review): appears unused here — confirm before removing
    _t_populate = _slog.log_start("populate") if _slog else None

    def log(msg: str):
        # Progress goes to stdout and, if provided, the callback.
        # BrokenPipeError is swallowed so a closed stdout can't kill the run.
        try:
            print(msg)
        except BrokenPipeError:
            pass  # Ignore broken pipe on stdout - process can continue
        if progress_callback:
            try:
                progress_callback(msg)
            except BrokenPipeError:
                pass

    try:
        log("Initializing LegitData generator...")
        log(f" Company: {company_url}")
        log(f" Use Case: {use_case}")
        log(f" Schema: {schema_name}")
        log(f" Size: {size}")
        gen = get_legitdata_generator(
            company_url=company_url,
            use_case=use_case,
            schema_name=schema_name,
            llm_model=llm_model,
            dry_run=False
        )
        log("Parsing DDL schema...")
        schema = gen.load_ddl(ddl_content)
        log(f" Found {len(schema.dimension_tables)} dimension tables")
        log(f" Found {len(schema.fact_tables)} fact tables")
        log(f"Generating {size} dataset...")
        log(" This may take a few minutes...")
        # Attempt 1
        results = {}
        first_attempt_error = None
        try:
            results = gen.generate(size=size, truncate_first=truncate_first)
        except RuntimeError as e:
            # Don't fail yet — remember the error and let the retry logic run.
            first_attempt_error = str(e)
            log(f"⚠️ First population attempt failed: {first_attempt_error}")
            if _slog:
                _slog.log_verbose("populate", "first attempt failed", error=first_attempt_error)
        # Check all tables for 0 rows even if no exception
        fact_table_names = {t.name for t in schema.fact_tables}
        empty_facts = [t for t in fact_table_names if results.get(t, 0) == 0]
        if not first_attempt_error and empty_facts:
            first_attempt_error = f"Fact table(s) empty after first attempt: {', '.join(empty_facts)}"
            log(f"⚠️ {first_attempt_error}")
        # Attempt 2 if anything failed
        if first_attempt_error:
            log("🔄 Retrying data population (attempt 2 of 2)...")
            try:
                # Retry always truncates so partial first-attempt rows are cleared.
                results = gen.generate(size=size, truncate_first=True)
            except RuntimeError as retry_e:
                if _slog:
                    _slog.log_verbose("populate", "retry attempt failed", error=str(retry_e))
                raise RuntimeError(
                    f"Population failed after 2 attempts. "
                    f"Attempt 1: {first_attempt_error} | Attempt 2: {retry_e}"
                )
        # Final check — report all table counts and fail if fact tables still empty
        log("\nFinal table row counts:")
        for table_name, count in sorted(results.items()):
            status = "✅" if count > 0 else "❌"
            log(f" {status} {table_name}: {count:,} rows")
        empty_facts_final = [t for t in fact_table_names if results.get(t, 0) == 0]
        if empty_facts_final:
            raise RuntimeError(
                f"Fact table(s) still empty after 2 attempts: {', '.join(empty_facts_final)}. "
                f"First attempt error: {first_attempt_error}"
            )
        # Format results
        total_rows = sum(results.values())
        table_lines = [f" - {table}: {count:,} rows" for table, count in results.items()]
        message = f"""LegitData Population Complete
Generated {total_rows:,} total rows:
{chr(10).join(table_lines)}
Data contextually generated for:
Company: {company_url}
Use Case: {use_case}
"""
        log(message)
        if _slog and _t_populate is not None:
            _slog.log_end("populate", _t_populate, tables=len(results), total_rows=total_rows)
        return True, message, results
    except Exception as e:
        # Any failure (including the final empty-fact check) lands here and
        # is reported as a (False, message, {}) result rather than raised.
        import traceback
        error_tb = traceback.format_exc()
        error_msg = f"""LegitData Population Failed
Error: {str(e)}
Troubleshooting:
- Check Snowflake connection credentials
- Verify schema '{schema_name}' exists
- Ensure DDL is valid
Traceback:
{error_tb}
"""
        log(error_msg)
        if _slog and _t_populate is not None:
            _slog.log_end("populate", _t_populate, error=str(e))
        return False, error_msg, {}
def preview_demo_data(
    ddl_content: str,
    company_url: str,
    use_case: str,
    llm_model: str,
    num_rows: int = 5
) -> Dict[str, list]:
    """
    Preview what data will be generated (dry run, no database writes).

    Useful for showing users what kind of data will be created
    before actually populating the database.

    Args:
        ddl_content: SQL DDL from schema generation
        company_url: Company website URL
        use_case: Use case description
        llm_model: Model name from app settings (OpenAI GPT/Codex only)
        num_rows: Number of sample rows to generate per table

    Returns:
        Dict mapping table names to lists of row dictionaries

    Example:
        preview = preview_demo_data(ddl, "https://amazon.com", "Retail Analytics", "gpt-4o")
        for table, rows in preview.items():
            print(f"{table}: {len(rows)} sample rows")
            print(rows[0])  # First row
    """
    gen = get_legitdata_generator(
        company_url=company_url,
        use_case=use_case,
        schema_name="preview",  # Doesn't matter for dry run
        llm_model=llm_model,
        dry_run=True,
        cache_enabled=False  # Disabled during development
    )
    gen.load_ddl(ddl_content)
    # Fix: get_legitdata_generator() returns a DemoLegitGenerator wrapper,
    # which historically did not expose preview() and crashed here with
    # AttributeError. Fall back to the wrapped base generator (stored in
    # gen._gen) when the wrapper lacks the method.
    preview_fn = getattr(gen, "preview", None)
    if preview_fn is None:
        preview_fn = gen._gen.preview
    return preview_fn(num_rows=num_rows)
def format_preview_for_display(preview: Dict[str, list], max_rows: int = 3) -> str:
    """
    Format preview data as markdown for UI display.

    Args:
        preview: Output from preview_demo_data()
        max_rows: Maximum rows to show per table

    Returns:
        Markdown-formatted string
    """
    import json

    def _jsonable(value):
        # Convert a single cell into something json.dumps can serialize.
        if value is None:
            return None
        if hasattr(value, 'isoformat'):  # datetime/date
            return value.isoformat()
        if isinstance(value, (str, int, float, bool)):
            return value
        return str(value)

    parts = ["## Data Preview\n\n", "Sample data that will be generated:\n\n"]
    for table, rows in preview.items():
        parts.append(f"### {table}\n\n")
        if not rows:
            parts.append("_No data generated_\n\n")
            continue
        # Render up to max_rows sample rows as pretty-printed JSON blocks.
        for idx, row in enumerate(rows[:max_rows], start=1):
            rendered = {key: _jsonable(val) for key, val in row.items()}
            parts.append(f"**Row {idx}:**\n")
            parts.append("```json\n")
            parts.append(json.dumps(rendered, indent=2))
            parts.append("\n```\n\n")
        hidden = len(rows) - max_rows
        if hidden > 0:
            parts.append(f"_...and {hidden} more rows_\n\n")
    return ''.join(parts)
def get_size_preset_info(size: str) -> dict:
    """
    Get information about a size preset.

    Args:
        size: Preset name (small/medium/large/xl); unknown names fall
            back to the "medium" preset.

    Returns:
        Dict with name, fact_rows, dim_rows, description
    """
    from legitdata.config import SIZE_PRESETS

    chosen = SIZE_PRESETS.get(size, SIZE_PRESETS["medium"])
    return {
        key: getattr(chosen, key)
        for key in ("name", "fact_rows", "dim_rows", "description")
    }
# Compatibility function - wraps old interface
def execute_population_with_legitdata(
    python_code_or_ddl: str,
    schema_name: str,
    company_url: str = None,
    use_case: str = None,
    llm_model: str = None,
    skip_modifications: bool = False
) -> Tuple[bool, str]:
    """
    Drop-in replacement for execute_population_script().

    This function provides backward compatibility with the old interface
    while using legitdata under the hood.

    Args:
        python_code_or_ddl: Either the old Python code (rejected) or DDL
        schema_name: Target schema name
        company_url: Company URL (required for legitdata)
        use_case: Use case (required for legitdata)
        llm_model: Model name from app settings
        skip_modifications: Ignored (for compatibility)

    Returns:
        Tuple of (success: bool, message: str)

    Note:
        For full functionality, use populate_demo_data() directly.
    """
    company_url = company_url or "https://example.com"
    use_case = use_case or "General Analytics"

    # Heuristic input check: DDL contains CREATE TABLE; anything else is
    # assumed to be legacy Python code, which legitdata cannot consume.
    if "CREATE TABLE" not in python_code_or_ddl.upper():
        return False, (
            "LegitData requires DDL, not Python code. "
            "Please pass schema_generation_results instead of data_population_results."
        )

    success, message, _results = populate_demo_data(
        ddl_content=python_code_or_ddl,
        company_url=company_url,
        use_case=use_case,
        schema_name=schema_name,
        llm_model=llm_model,
        size="medium"
    )
    return success, message
# Quick test: prints each size preset's row counts when run as a script.
if __name__ == "__main__":
    print("LegitData Bridge Module")
    print("=" * 40)
    # Test size presets
    for size in ["small", "medium", "large", "xl"]:
        info = get_size_preset_info(size)
        print(f"{size}: {info['fact_rows']} facts, {info['dim_rows']} dims - {info['description']}")