|
|
import os |
|
|
from sqlalchemy import create_engine, text |
|
|
from sqlalchemy.orm import sessionmaker, declarative_base |
|
|
from sqlalchemy.pool import NullPool |
|
|
|
|
|
|
|
|
# Detect whether a SQLAlchemy dialect for "cockroachdb://" URLs is installed.
# The legacy ``cockroachdb`` distribution provides it as
# ``cockroachdb.sqlalchemy.dialect``; the maintained ``sqlalchemy-cockroachdb``
# distribution provides it as the ``sqlalchemy_cockroachdb`` module. Both
# register the "cockroachdb" dialect name, so either one is sufficient.
try:
    import cockroachdb.sqlalchemy.dialect  # noqa: F401  (imported only to probe availability)

    COCKROACHDB_AVAILABLE = True
except ImportError:
    try:
        import sqlalchemy_cockroachdb  # noqa: F401  (imported only to probe availability)

        COCKROACHDB_AVAILABLE = True
    except ImportError:
        COCKROACHDB_AVAILABLE = False
|
|
|
|
|
|
|
|
|
|
|
# Connection URL exactly as configured (defaults to a local SQLite file).
# Kept unmodified so the raw psycopg2 helpers can re-parse the user's settings.
ORIGINAL_DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./postgen.db")

# URL handed to SQLAlchemy; the PostgreSQL/CockroachDB branch below may rewrite
# its scheme and add or adjust the sslmode query parameter.
DATABASE_URL = ORIGINAL_DATABASE_URL

if ORIGINAL_DATABASE_URL.startswith(("postgresql://", "postgres://", "cockroachdb://")):
    # Heuristic: CockroachDB Cloud connection strings contain a
    # "cockroachlabs" hostname.
    is_cockroach = "cockroachlabs" in ORIGINAL_DATABASE_URL.lower()

    # SQLAlchemy 1.4+ removed the legacy "postgres" dialect alias, so a
    # "postgres://" URL would fail to load a dialect. Normalize it to the
    # canonical "postgresql://" scheme before anything else inspects it.
    if DATABASE_URL.startswith("postgres://"):
        DATABASE_URL = "postgresql://" + DATABASE_URL[len("postgres://"):]

    # libpq's default location for the CA bundle used by sslmode=verify-full.
    cert_path = os.path.expanduser("~/.postgresql/root.crt")

    if "sslmode" not in DATABASE_URL:
        # Require TLS unless the URL explicitly chooses an sslmode.
        separator = "&" if "?" in DATABASE_URL else "?"
        DATABASE_URL = f"{DATABASE_URL}{separator}sslmode=require"
    elif (
        "sslmode=verify-full" in DATABASE_URL
        # An explicit sslrootcert means the default CA path is irrelevant;
        # don't silently weaken certificate verification in that case.
        and "sslrootcert" not in DATABASE_URL
        and not os.path.exists(cert_path)
    ):
        # Without a CA bundle verify-full cannot succeed; fall back to
        # encrypted-but-unverified TLS.
        DATABASE_URL = DATABASE_URL.replace("sslmode=verify-full", "sslmode=require")
        print("⚠ Certificate file not found, using sslmode=require instead of verify-full")

    # Prefer the dedicated CockroachDB dialect when it is installed.
    if is_cockroach and COCKROACHDB_AVAILABLE and DATABASE_URL.startswith("postgresql://"):
        DATABASE_URL = "cockroachdb://" + DATABASE_URL[len("postgresql://"):]

    # NullPool opens a fresh DB connection per checkout instead of pooling
    # (presumably chosen for serverless / short-lived deployments - confirm).
    engine = create_engine(
        DATABASE_URL,
        poolclass=NullPool,
        echo=False,
        connect_args={},
    )
else:
    # sqlite3 connections refuse use from other threads by default; disable
    # that check so sessions can be shared with worker threads. Only applied
    # to genuine SQLite URLs (a substring match could hit e.g. a password).
    engine = create_engine(
        DATABASE_URL,
        connect_args={"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {},
    )
|
|
|
|
|
# Factory for ORM sessions bound to the module engine. Neither autocommit nor
# autoflush is enabled, so callers flush/commit explicitly.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)

# Declarative base class; init_db() builds tables from its metadata after
# importing app.models (presumably the models subclass this Base).
Base = declarative_base()
|
|
|
|
|
def get_db():
    """Yield a fresh database session for the duration of one consumer.

    Generator-style dependency: the session is created on first iteration and
    closed when the generator is finalized, whether or not the consumer raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Always release the session (and its connection) back to the engine.
        session.close()
|
|
|
|
|
def _psycopg2_connect_kwargs(url):
    """Translate a postgres/cockroachdb URL into ``psycopg2.connect`` keyword arguments.

    Defaults: database ``defaultdb``, port 26257, and ``sslmode=require`` when
    the URL does not specify them. User name, password and database name are
    percent-decoded, matching how libpq and SQLAlchemy interpret URLs.
    """
    from urllib.parse import parse_qs, unquote, urlparse

    # Standard libpq connection keywords forwarded verbatim from the query
    # string when present (e.g. a custom CA path, or CockroachDB's
    # ``options=--cluster=...`` routing parameter).
    passthrough_params = (
        "sslrootcert",
        "sslcert",
        "sslkey",
        "options",
        "application_name",
        "connect_timeout",
    )

    # urlparse only understands the userinfo/host layout, not the scheme, but
    # normalize anyway so the URL looks like a regular PostgreSQL URI.
    if url.startswith("cockroachdb://"):
        url = "postgresql://" + url[len("cockroachdb://"):]
    parsed = urlparse(url)
    params = parse_qs(parsed.query)

    def _decoded(value):
        # urlparse leaves percent-escapes in userinfo/path; decode them so a
        # password such as "p%40ss" reaches the server as "p@ss".
        return unquote(value) if value is not None else None

    kwargs = {
        # path[1:] strips the leading "/"; "" or "/" both mean "no database given".
        "dbname": _decoded(parsed.path[1:]) or "defaultdb",
        "user": _decoded(parsed.username),
        "password": _decoded(parsed.password),
        "host": parsed.hostname,
        "port": parsed.port or 26257,
        # parse_qs never yields empty lists, so [0] is safe.
        "sslmode": params.get("sslmode", ["require"])[0],
    }
    for key in passthrough_params:
        if key in params:
            kwargs[key] = params[key][0]
    return kwargs


def get_direct_psycopg2_connection():
    """Get a direct psycopg2 connection bypassing SQLAlchemy version parsing.

    Connection settings come from ``ORIGINAL_DATABASE_URL``.

    Returns:
        An open psycopg2 connection, or ``None`` when the configured URL is not
        a PostgreSQL/CockroachDB URL or the connection attempt failed (the
        failure is printed, not raised).
    """
    try:
        # Non-Postgres URLs (e.g. the SQLite default) never need psycopg2;
        # bail out before importing it so SQLite setups see no error output.
        if not ORIGINAL_DATABASE_URL.startswith(("postgresql://", "postgres://", "cockroachdb://")):
            return None

        import psycopg2

        return psycopg2.connect(**_psycopg2_connect_kwargs(ORIGINAL_DATABASE_URL))
    except Exception as e:
        print(f"Failed to create direct psycopg2 connection: {e}")
        return None
|
|
|
|
|
def _insert_default_user(conn):
    """Make sure user id=1 exists using ``conn``; return the default user's id.

    Returns 1 when the row already exists or was inserted, or the id of an
    existing row matching id 1 / the default email if both inserts failed.
    Raises the last insert error when no such row can be found.
    """
    cursor = conn.cursor()
    try:
        cursor.execute("SELECT id FROM users WHERE id = 1")
        if cursor.fetchone():
            return 1

        # NOTE(review): inserting an explicit id into a SERIAL column does not
        # advance a PostgreSQL sequence, so later default-id inserts may
        # collide with id 1 - confirm against the target database.
        try:
            cursor.execute("""
                INSERT INTO users (id, email, name, created_at)
                VALUES (1, 'default@postgen.app', 'Default User', NOW())
                ON CONFLICT (id) DO NOTHING
            """)
        except Exception:
            # A failed statement aborts the whole transaction in PostgreSQL
            # and CockroachDB; roll back so the retry is not rejected outright.
            conn.rollback()
            try:
                # Fallback for servers that reject ON CONFLICT.
                cursor.execute("""
                    INSERT INTO users (id, email, name, created_at)
                    VALUES (1, 'default@postgen.app', 'Default User', NOW())
                """)
            except Exception:
                conn.rollback()
                # The insert may have failed because an equivalent row exists
                # (e.g. the default email under another id); reuse it.
                cursor.execute("SELECT id FROM users WHERE id = 1 OR email = 'default@postgen.app' LIMIT 1")
                row = cursor.fetchone()
                if row:
                    return row[0]
                raise

        conn.commit()
        return 1
    finally:
        cursor.close()


def _first_user_id(conn):
    """Return the id of any existing user, or ``None`` if there is none or the lookup fails."""
    try:
        # Clear the aborted transaction left behind by the failed insert.
        conn.rollback()
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT id FROM users LIMIT 1")
            row = cursor.fetchone()
        finally:
            cursor.close()
    except Exception:
        return None
    return row[0] if row else None


def ensure_default_user():
    """Ensure a default user (id=1) exists in the database.

    Works through a raw psycopg2 connection (``get_direct_psycopg2_connection``),
    so it only touches the database for PostgreSQL/CockroachDB URLs. Never
    raises; every failure is printed.

    Returns:
        The id to use as the default user: 1 when the row exists or was created,
        otherwise the id of an existing matching row or of any existing user,
        and 1 as the last-resort fallback (including when no connection exists).
    """
    try:
        conn = get_direct_psycopg2_connection()
        if not conn:
            return 1
        try:
            return _insert_default_user(conn)
        except Exception as e:
            existing_id = _first_user_id(conn)
            if existing_id is not None:
                return existing_id
            print(f"Warning: Could not ensure default user: {e}")
            return 1
        finally:
            conn.close()
    except Exception as e:
        print(f"Error ensuring default user: {e}")
        return 1
|
|
|
|
|
# DDL used when SQLAlchemy's create_all() fails with "Could not determine
# version" (the dialect could not parse the server version string; the
# messages below attribute this to CockroachDB). Presumably mirrors the ORM
# models in app.models - keep the two in sync (TODO confirm).
_FALLBACK_TABLES_SQL = (
    """CREATE TABLE IF NOT EXISTS users (
        id SERIAL PRIMARY KEY,
        email VARCHAR UNIQUE,
        name VARCHAR,
        created_at TIMESTAMP DEFAULT NOW()
    )""",
    """CREATE TABLE IF NOT EXISTS integrations (
        id SERIAL PRIMARY KEY,
        user_id INTEGER REFERENCES users(id),
        provider VARCHAR,
        access_token TEXT,
        refresh_token TEXT,
        expires_at TIMESTAMP,
        account_info JSONB,
        connected BOOLEAN DEFAULT FALSE,
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW()
    )""",
    """CREATE TABLE IF NOT EXISTS assets (
        id SERIAL PRIMARY KEY,
        user_id INTEGER REFERENCES users(id),
        name VARCHAR,
        file_path VARCHAR,
        file_type VARCHAR,
        product_category VARCHAR,
        sub_category VARCHAR,
        size INTEGER,
        extra_metadata JSONB,
        extracted_content JSONB,
        analysis_status VARCHAR DEFAULT 'pending',
        analyzed_at TIMESTAMP,
        created_at TIMESTAMP DEFAULT NOW()
    )""",
    """CREATE TABLE IF NOT EXISTS posts (
        id SERIAL PRIMARY KEY,
        user_id INTEGER REFERENCES users(id),
        title VARCHAR,
        content TEXT,
        post_type VARCHAR,
        product_category VARCHAR,
        scheduled_date TIMESTAMP,
        status VARCHAR,
        linkedin_post_id VARCHAR,
        canva_design_id VARCHAR,
        assets JSONB,
        extra_metadata JSONB,
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW()
    )""",
    """CREATE TABLE IF NOT EXISTS campaigns (
        id SERIAL PRIMARY KEY,
        user_id INTEGER REFERENCES users(id),
        name VARCHAR,
        date_range_start TIMESTAMP,
        date_range_end TIMESTAMP,
        products JSONB,
        post_types JSONB,
        posts_per_week INTEGER,
        status VARCHAR,
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW()
    )""",
)

# Columns that may be missing from an `assets` table created by an older
# schema, as (column name, column definition). Each is added if absent.
_ASSET_COLUMN_MIGRATIONS = (
    ("extracted_content", "JSONB"),
    ("analysis_status", "VARCHAR DEFAULT 'pending'"),
    ("analyzed_at", "TIMESTAMP"),
)


def _migrate_asset_columns(conn):
    """Add any missing `assets` columns on ``conn``; failures are printed, not raised."""
    cursor = conn.cursor()
    try:
        for column, definition in _ASSET_COLUMN_MIGRATIONS:
            cursor.execute(
                "SELECT column_name FROM information_schema.columns "
                "WHERE table_name = %s AND column_name = %s",
                ("assets", column),
            )
            if not cursor.fetchone():
                # Column names/definitions come from the constant above, never
                # from user input, so interpolating them into DDL is safe.
                cursor.execute(f"ALTER TABLE assets ADD COLUMN {column} {definition}")
                print(f"✓ Added {column} column")
        conn.commit()
        print("✓ Database migration completed (added new asset columns)")
    except Exception as migration_error:
        # Best effort: a failed migration must not abort startup.
        print(f"Migration note: {migration_error}")
    finally:
        cursor.close()


def _create_tables_with_psycopg2():
    """Create all tables through a raw psycopg2 connection, bypassing SQLAlchemy.

    Raises:
        RuntimeError: if no direct connection could be opened.
        Exception: any database error raised while creating the tables.
    """
    conn = get_direct_psycopg2_connection()
    if conn is None:
        raise RuntimeError("could not open a direct psycopg2 connection")
    try:
        cursor = conn.cursor()
        try:
            for sql in _FALLBACK_TABLES_SQL:
                cursor.execute(sql)
            conn.commit()
        finally:
            cursor.close()
        _migrate_asset_columns(conn)
    finally:
        # Close even when table creation fails part-way.
        conn.close()


def init_db():
    """Initialize database tables.

    Tries SQLAlchemy's ``create_all`` first. If that fails with a
    "Could not determine version" error, falls back to raw DDL over psycopg2.

    Returns:
        True when tables were created, or when only the version-parsing problem
        (or the raw fallback) failed - those are reported as non-fatal.
        False for any other error.
    """
    try:
        # Imported for their side effect: the model classes are presumably
        # registered on Base.metadata, which create_all() below relies on.
        from app.models import User, Integration, Asset, Post, Campaign  # noqa: F401

        try:
            Base.metadata.create_all(bind=engine)
            print("✓ Database tables created successfully")
            return True
        except Exception as create_error:
            if "Could not determine version" not in str(create_error):
                raise

            try:
                _create_tables_with_psycopg2()
            except Exception as raw_error:
                print(f"⚠ Table creation failed: {raw_error}")
                print("✓ Database connection works - tables will be created on first use")
                return True

            print("✓ CockroachDB tables created successfully (using direct psycopg2 connection)")
            return True
    except Exception as e:
        if "Could not determine version" in str(e):
            print("⚠ CockroachDB version parsing issue (non-fatal)")
            print("✓ Database connection works - tables will be created on first use")
            return True
        print(f"Database connection failed: {e}")
        return False
|
|
|
|
|
|