Spaces:
Sleeping
Sleeping
| import os | |
| from sqlalchemy import create_engine, text | |
| from sqlalchemy.orm import sessionmaker, declarative_base | |
| from sqlalchemy.pool import NullPool | |
| # Try to use CockroachDB dialect if available | |
| try: | |
| import cockroachdb.sqlalchemy.dialect | |
| COCKROACHDB_AVAILABLE = True | |
| except ImportError: | |
| COCKROACHDB_AVAILABLE = False | |
# Database URL from the DATABASE_URL environment variable; falls back to a
# local SQLite file for development when it is not set.
ORIGINAL_DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "sqlite:///./postgen.db"
)
# URL actually handed to SQLAlchemy; may be rewritten below (sslmode, scheme).
DATABASE_URL = ORIGINAL_DATABASE_URL

if ORIGINAL_DATABASE_URL.startswith(("postgresql://", "postgres://", "cockroachdb://")):
    # CockroachDB detection is done on the unmodified URL (cockroachlabs host).
    is_cockroach = "cockroachlabs" in ORIGINAL_DATABASE_URL.lower()

    # Default to sslmode=require (encrypted, no CA certificate needed).
    # An explicit verify-full is downgraded to require when the default
    # root certificate file is missing, so the connection can still be made.
    cert_path = os.path.expanduser("~/.postgresql/root.crt")
    if "sslmode" not in DATABASE_URL:
        separator = "&" if "?" in DATABASE_URL else "?"
        DATABASE_URL = f"{DATABASE_URL}{separator}sslmode=require"
    elif "sslmode=verify-full" in DATABASE_URL and not os.path.exists(cert_path):
        DATABASE_URL = DATABASE_URL.replace("sslmode=verify-full", "sslmode=require")
        print("β Certificate file not found, using sslmode=require instead of verify-full")

    # SQLAlchemy 1.4+ no longer accepts the legacy "postgres://" dialect name;
    # normalize it so create_engine can resolve the PostgreSQL dialect.
    if DATABASE_URL.startswith("postgres://"):
        DATABASE_URL = "postgresql://" + DATABASE_URL[len("postgres://"):]

    # Use the CockroachDB dialect when it imported successfully and the host
    # looks like CockroachDB. Only the scheme prefix is rewritten.
    if is_cockroach and COCKROACHDB_AVAILABLE and DATABASE_URL.startswith("postgresql://"):
        DATABASE_URL = "cockroachdb://" + DATABASE_URL[len("postgresql://"):]

    engine = create_engine(
        DATABASE_URL,
        poolclass=NullPool,  # no pooling: each checkout opens a fresh DB-API connection
        echo=False,  # Set to True for SQL query debugging
    )
else:
    # Any other URL (SQLite by default). For SQLite, disable sqlite3's
    # same-thread check so a connection may be used from another thread.
    engine = create_engine(
        DATABASE_URL,
        connect_args={"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {},
    )

# Session factory: explicit commits, no automatic flush before queries.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base class for the ORM models.
Base = declarative_base()
def get_db():
    """Yield a database session and close it when the caller is finished.

    Written as a generator-style dependency: a fresh ``SessionLocal`` session
    is produced for the consumer, and it is closed in all cases, including
    when the consumer raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Always release the session, regardless of how the consumer exited.
        session.close()
def _psycopg2_connect_kwargs(url):
    """Translate a postgres/cockroachdb URL into ``psycopg2.connect`` keyword arguments.

    Username, password and database name are percent-decoded, because
    ``urlparse`` leaves them encoded. Defaults: database ``defaultdb``,
    port 26257, ``sslmode=require``. ``sslmode=verify-full`` is downgraded to
    ``require`` when the root certificate (``sslrootcert`` from the URL, or
    ``~/.postgresql/root.crt``) does not exist.
    """
    from urllib.parse import parse_qs, unquote, urlparse

    if url.startswith("cockroachdb://"):
        url = "postgresql://" + url[len("cockroachdb://"):]
    parsed = urlparse(url)
    params = parse_qs(parsed.query)  # never yields empty value lists

    sslmode = params.get("sslmode", ["require"])[0]
    cert_path = os.path.expanduser(
        params.get("sslrootcert", ["~/.postgresql/root.crt"])[0]
    )
    cert_exists = os.path.exists(cert_path)
    if sslmode == "verify-full" and not cert_exists:
        # Without a root certificate verify-full cannot succeed; still use SSL.
        sslmode = "require"

    kwargs = {
        "dbname": unquote(parsed.path[1:]) or "defaultdb",
        "user": unquote(parsed.username) if parsed.username is not None else None,
        "password": unquote(parsed.password) if parsed.password is not None else None,
        "host": parsed.hostname,
        "port": parsed.port or 26257,
        "sslmode": sslmode,
    }
    # Honor an explicit root certificate from the URL when it is present on disk.
    if "sslrootcert" in params and cert_exists:
        kwargs["sslrootcert"] = cert_path
    return kwargs


def get_direct_psycopg2_connection(url=None):
    """Get a direct psycopg2 connection bypassing SQLAlchemy version parsing.

    Args:
        url: Connection URL to use. Defaults to ``ORIGINAL_DATABASE_URL``, the
            URL as read from the environment before this module rewrote its
            scheme or sslmode.

    Returns:
        An open psycopg2 connection. Returns ``None`` when the URL is not a
        ``postgresql://``, ``postgres://`` or ``cockroachdb://`` URL, or when
        connecting fails. Failures are printed, not raised.
    """
    if url is None:
        url = ORIGINAL_DATABASE_URL
    # Only PostgreSQL/CockroachDB URLs can be opened with psycopg2.
    if not url.startswith(("postgresql://", "postgres://", "cockroachdb://")):
        return None
    try:
        import psycopg2

        return psycopg2.connect(**_psycopg2_connect_kwargs(url))
    except Exception as e:
        print(f"Failed to create direct psycopg2 connection: {e}")
        return None
def _insert_default_user(conn):
    """Return the default user's id on *conn*, inserting the id=1 row if missing.

    Re-raises the last INSERT error when the row can neither be inserted nor
    found afterwards.
    """
    with conn.cursor() as cursor:
        cursor.execute("SELECT id FROM users WHERE id = 1")
        if cursor.fetchone():
            return 1
        try:
            # ON CONFLICT covers a concurrent insert of the same id.
            cursor.execute("""
                INSERT INTO users (id, email, name, created_at)
                VALUES (1, 'default@postgen.app', 'Default User', NOW())
                ON CONFLICT (id) DO NOTHING
            """)
        except Exception:
            # A failed statement aborts the transaction in PostgreSQL and
            # CockroachDB; roll back so the next statement can run at all.
            conn.rollback()
            try:
                cursor.execute("""
                    INSERT INTO users (id, email, name, created_at)
                    VALUES (1, 'default@postgen.app', 'Default User', NOW())
                """)
            except Exception:
                conn.rollback()
                # The row may exist already (race), or the email may be taken
                # by another user; reuse whichever row matches.
                cursor.execute(
                    "SELECT id FROM users WHERE id = 1 OR email = 'default@postgen.app' LIMIT 1"
                )
                row = cursor.fetchone()
                if row:
                    return row[0]
                raise  # the second INSERT's error
        conn.commit()
        return 1


def ensure_default_user():
    """Ensure a default user (id=1) exists in the database.

    Best-effort. Errors are printed, never raised, and the function falls back
    to returning ``1``.

    Returns:
        int: ``1`` when no direct connection is available or the default user
        exists or was created. Otherwise the id of a user matching id=1 or the
        default email, or of any existing user, or ``1`` as a last resort.
    """
    try:
        conn = get_direct_psycopg2_connection()
        if not conn:
            return 1  # Return default ID if connection fails
        try:
            return _insert_default_user(conn)
        except Exception as e:
            # Fall back to any existing user before giving up.
            try:
                conn.rollback()  # clear an aborted transaction before querying
                with conn.cursor() as cursor:
                    cursor.execute("SELECT id FROM users LIMIT 1")
                    row = cursor.fetchone()
                if row:
                    return row[0]
            except Exception:
                pass  # deliberate: the warning below reports the original failure
            print(f"Warning: Could not ensure default user: {e}")
            return 1  # Return default ID as fallback
        finally:
            conn.close()
    except Exception as e:
        print(f"Error ensuring default user: {e}")
        return 1  # Return default ID as fallback
# Raw DDL used by init_db when SQLAlchemy's create_all fails with a version
# parsing error. NOTE(review): presumably mirrors app.models; keep in sync.
_FALLBACK_TABLES_SQL = [
    """CREATE TABLE IF NOT EXISTS users (
        id SERIAL PRIMARY KEY,
        email VARCHAR UNIQUE,
        name VARCHAR,
        created_at TIMESTAMP DEFAULT NOW()
    )""",
    """CREATE TABLE IF NOT EXISTS integrations (
        id SERIAL PRIMARY KEY,
        user_id INTEGER REFERENCES users(id),
        provider VARCHAR,
        access_token TEXT,
        refresh_token TEXT,
        expires_at TIMESTAMP,
        account_info JSONB,
        connected BOOLEAN DEFAULT FALSE,
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW()
    )""",
    """CREATE TABLE IF NOT EXISTS assets (
        id SERIAL PRIMARY KEY,
        user_id INTEGER REFERENCES users(id),
        name VARCHAR,
        file_path VARCHAR,
        file_type VARCHAR,
        product_category VARCHAR,
        sub_category VARCHAR,
        size INTEGER,
        extra_metadata JSONB,
        extracted_content JSONB,
        analysis_status VARCHAR DEFAULT 'pending',
        analyzed_at TIMESTAMP,
        created_at TIMESTAMP DEFAULT NOW()
    )""",
    """CREATE TABLE IF NOT EXISTS posts (
        id SERIAL PRIMARY KEY,
        user_id INTEGER REFERENCES users(id),
        title VARCHAR,
        content TEXT,
        post_type VARCHAR,
        product_category VARCHAR,
        scheduled_date TIMESTAMP,
        status VARCHAR,
        linkedin_post_id VARCHAR,
        canva_design_id VARCHAR,
        assets JSONB,
        extra_metadata JSONB,
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW()
    )""",
    """CREATE TABLE IF NOT EXISTS campaigns (
        id SERIAL PRIMARY KEY,
        user_id INTEGER REFERENCES users(id),
        name VARCHAR,
        date_range_start TIMESTAMP,
        date_range_end TIMESTAMP,
        products JSONB,
        post_types JSONB,
        posts_per_week INTEGER,
        status VARCHAR,
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW()
    )""",
]

# (column name, DDL) pairs for columns added to an existing `assets` table.
_ASSET_COLUMN_MIGRATIONS = [
    ("extracted_content", "ALTER TABLE assets ADD COLUMN extracted_content JSONB"),
    ("analysis_status", "ALTER TABLE assets ADD COLUMN analysis_status VARCHAR DEFAULT 'pending'"),
    ("analyzed_at", "ALTER TABLE assets ADD COLUMN analyzed_at TIMESTAMP"),
]


def _migrate_asset_columns(conn):
    """Add any missing columns from _ASSET_COLUMN_MIGRATIONS to ``assets``.

    Each column is checked in information_schema before it is added (the
    original authors note that CockroachDB does not support ALTER TABLE inside
    DO blocks). Failures are printed and otherwise ignored.
    """
    try:
        with conn.cursor() as cursor:
            for column, ddl in _ASSET_COLUMN_MIGRATIONS:
                cursor.execute(
                    """
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_name='assets' AND column_name=%s
                    """,
                    (column,),
                )
                if not cursor.fetchone():
                    cursor.execute(ddl)
                    print(f"β Added {column} column")
        conn.commit()
        print("β Database migration completed (added new asset columns)")
    except Exception as migration_error:
        # Migration might fail if columns already exist, that's okay
        print(f"Migration note: {migration_error}")


def _create_tables_with_psycopg2():
    """Create all tables with raw DDL over a direct psycopg2 connection.

    Raises:
        RuntimeError: if no direct connection could be opened.
        Exception: any error raised while executing the table DDL.
    """
    conn = get_direct_psycopg2_connection()
    if conn is None:
        raise RuntimeError("could not open a direct psycopg2 connection")
    try:
        with conn.cursor() as cursor:
            for sql in _FALLBACK_TABLES_SQL:
                cursor.execute(sql)
        conn.commit()
        _migrate_asset_columns(conn)
    finally:
        conn.close()  # released even when DDL fails


def init_db():
    """Initialize database tables.

    Tries SQLAlchemy's ``create_all`` first. If that fails with a
    "Could not determine version" error (seen with CockroachDB), it falls back
    to raw DDL over a direct psycopg2 connection.

    Returns:
        bool: ``True`` when tables were created, or when only the version
        parsing failed (including a failed fallback, which is printed).
        ``False`` for any other error.
    """
    try:
        # NOTE(review): presumably imported so the models are registered on
        # Base.metadata before create_all; the names themselves are unused.
        from app.models import User, Integration, Asset, Post, Campaign  # noqa: F401
        try:
            Base.metadata.create_all(bind=engine)
            print("β Database tables created successfully")
            return True
        except Exception as create_error:
            if "Could not determine version" not in str(create_error):
                raise  # Real error, not just version parsing
            # Version parsing failed; bypass SQLAlchemy and use psycopg2 directly.
            try:
                _create_tables_with_psycopg2()
            except Exception as raw_error:
                print(f"β Table creation failed: {raw_error}")
                print("β Database connection works - tables will be created on first use")
                return True  # Connection works, tables will be created later
            print("β CockroachDB tables created successfully (using direct psycopg2 connection)")
            return True
    except Exception as e:
        if "Could not determine version" in str(e):
            print("β CockroachDB version parsing issue (non-fatal)")
            print("β Database connection works - tables will be created on first use")
            return True  # Connection works, return True
        print(f"Database connection failed: {e}")
        return False