Spaces:
Sleeping
Sleeping
| import os | |
| from typing import Generator | |
| from sqlmodel import SQLModel, create_engine, Session | |
| from sqlalchemy import inspect, text | |
| from dotenv import load_dotenv | |
# Run dotenv loading first so that any value it places in the process
# environment is visible to the DATABASE_URL lookup below.
load_dotenv()

# Connection URL for the application database. When the DATABASE_URL
# environment variable is unset, a local PostgreSQL instance is used.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql://postgres:postgres@localhost:5432/paper_insight",
)

# Module-wide engine shared by the schema helpers and session factories
# defined below; SQL statement echoing is turned off.
engine = create_engine(DATABASE_URL, echo=False)
def create_db_and_tables():
    """Create every table registered on ``SQLModel.metadata``.

    The tables are created through the module-level ``engine``.
    """
    metadata = SQLModel.metadata
    metadata.create_all(engine)
def ensure_appsettings_schema():
    """Bring a legacy ``appsettings`` table up to the expected column set.

    Returns immediately when the ``appsettings`` table does not exist.
    Otherwise, within one ``engine.begin()`` block:

    1. every expected column that is missing is added via ``ALTER TABLE``;
    2. NULL values in each expected column are replaced by that column's
       default value.

    The NULL backfill runs on every call, including calls where no column
    had to be added, so rows with NULLs in pre-existing columns are also
    normalised.
    """
    # (column name, SQL type, SQL literal that replaces NULLs).
    # The tuple order is the order in which missing columns are appended
    # to the table.
    expected_columns = (
        ("research_focus", "TEXT", "''"),
        ("focus_keywords", "JSON", "'[]'"),
        ("system_prompt", "TEXT", "''"),
        ("arxiv_categories", "JSON", "'[\"cs.CV\",\"cs.LG\"]'"),
    )

    inspector = inspect(engine)
    if "appsettings" not in inspector.get_table_names():
        return

    existing = {col["name"] for col in inspector.get_columns("appsettings")}

    with engine.begin() as conn:
        # Add whatever is missing first so the backfill below can rely on
        # all expected columns being present.
        for name, sql_type, _default in expected_columns:
            if name not in existing:
                conn.execute(
                    text(f"ALTER TABLE appsettings ADD COLUMN {name} {sql_type}")
                )

        # Every expected column now exists (it was either already there or
        # has just been added), so no per-column existence check is needed.
        for name, _sql_type, default in expected_columns:
            conn.execute(
                text(
                    f"UPDATE appsettings SET {name} = {default} "
                    f"WHERE {name} IS NULL"
                )
            )
def ensure_paper_schema():
    """Ensure the paper table has a populated ``processing_status`` column.

    The table is looked up under the name ``paper`` first and ``papers``
    second; when neither exists the function returns without doing
    anything. Otherwise, within one ``engine.begin()`` block:

    1. ``processing_status`` (TEXT) is added if it is missing;
    2. rows whose ``processing_status`` is NULL get ``'processed'`` when
       ``is_processed`` is true and ``'pending'`` otherwise;
    3. rows that are processed, have a non-NULL ``relevance_score`` below 5
       and are currently ``'processed'`` are re-labelled ``'skipped'``.

    Steps 2 and 3 run on every call. They reference the ``is_processed``
    and ``relevance_score`` columns without checking that they exist.
    """
    inspector = inspect(engine)

    # Fetch the table list once; prefer the singular table name.
    table_names = inspector.get_table_names()
    table_name = next(
        (candidate for candidate in ("paper", "papers") if candidate in table_names),
        None,
    )
    if table_name is None:
        return

    existing = {col["name"] for col in inspector.get_columns(table_name)}

    with engine.begin() as conn:
        if "processing_status" not in existing:
            conn.execute(
                text(f"ALTER TABLE {table_name} ADD COLUMN processing_status TEXT")
            )

        # The column is guaranteed to exist from here on, so the backfill
        # statements run unconditionally.
        conn.execute(
            text(
                f"UPDATE {table_name} "
                "SET processing_status = CASE "
                "WHEN is_processed THEN 'processed' ELSE 'pending' END "
                "WHERE processing_status IS NULL"
            )
        )
        conn.execute(
            text(
                f"UPDATE {table_name} "
                "SET processing_status = 'skipped' "
                "WHERE is_processed = TRUE "
                "AND relevance_score IS NOT NULL "
                "AND relevance_score < 5 "
                "AND processing_status = 'processed'"
            )
        )
def get_session() -> Generator[Session, None, None]:
    """Yield a database session bound to the module-level engine.

    Written as a generator so it can be used as a dependency: the session
    is handed to the consumer and its context is exited once the consumer
    is finished with it.
    """
    session = Session(engine)
    with session:
        yield session
def get_sync_session() -> Session:
    """Return a new session for code running outside FastAPI.

    The session is not wrapped in a context manager here, so closing it is
    left to the caller.
    """
    session = Session(engine)
    return session