File size: 4,297 Bytes
353b9f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
from typing import Generator
from sqlmodel import SQLModel, create_engine, Session
from sqlalchemy import inspect, text
from dotenv import load_dotenv

# Load variables via python-dotenv before DATABASE_URL is looked up below.
load_dotenv()

# Fallback connection string used when DATABASE_URL is not set in the
# environment: a PostgreSQL database named "paper_insight" on localhost.
_DEFAULT_DATABASE_URL = "postgresql://postgres:postgres@localhost:5432/paper_insight"

DATABASE_URL = os.environ.get("DATABASE_URL", _DEFAULT_DATABASE_URL)

# Single engine shared by every helper in this module; SQL echoing is off.
engine = create_engine(DATABASE_URL, echo=False)


def create_db_and_tables():
    """Issue CREATE TABLE for everything registered on ``SQLModel.metadata``.

    Uses the module-level ``engine`` as the target database.
    """
    metadata = SQLModel.metadata
    metadata.create_all(engine)


def ensure_appsettings_schema():
    """Ensure AppSettings has expected columns for legacy databases.

    Returns immediately when the ``appsettings`` table does not exist.
    Otherwise, inside a single ``engine.begin()`` block, it:

    1. adds whichever of ``research_focus``, ``focus_keywords``,
       ``system_prompt`` and ``arxiv_categories`` are missing, then
    2. replaces NULLs in all four columns with their defaults
       (``''``, ``'[]'``, ``''`` and ``'["cs.CV","cs.LG"]'`` respectively).

    The backfill runs on every call, not only when a column was just added.
    """
    inspector = inspect(engine)
    if "appsettings" not in inspector.get_table_names():
        return

    existing = {col["name"] for col in inspector.get_columns("appsettings")}

    # (column, SQL type), in the order the ALTER TABLE statements are issued.
    column_types = (
        ("research_focus", "TEXT"),
        ("focus_keywords", "JSON"),
        ("system_prompt", "TEXT"),
        ("arxiv_categories", "JSON"),
    )
    # (column, SQL literal that replaces NULL), in backfill order.
    null_defaults = (
        ("research_focus", "''"),
        ("system_prompt", "''"),
        ("focus_keywords", "'[]'"),
        ("arxiv_categories", "'[\"cs.CV\",\"cs.LG\"]'"),
    )

    # The SQL below is assembled only from the constants above, never from
    # external input, so interpolating into the statement text is safe here.
    with engine.begin() as conn:
        for name, sql_type in column_types:
            if name not in existing:
                conn.execute(
                    text(f"ALTER TABLE appsettings ADD COLUMN {name} {sql_type}")
                )

        # Every column is either pre-existing or was added just above, so no
        # per-column existence check is needed before backfilling.
        for name, default in null_defaults:
            conn.execute(
                text(
                    f"UPDATE appsettings SET {name} = {default} "
                    f"WHERE {name} IS NULL"
                )
            )


def ensure_paper_schema():
    """Ensure Paper has expected columns for legacy databases.

    Looks for a table named ``paper`` (preferred) or ``papers`` and returns
    immediately when neither exists. Otherwise, inside a single
    ``engine.begin()`` block, it:

    1. adds a ``processing_status TEXT`` column if it is missing;
    2. fills NULL ``processing_status`` values with ``'processed'`` or
       ``'pending'`` depending on ``is_processed``;
    3. changes ``'processed'`` to ``'skipped'`` for rows where
       ``is_processed`` is true and ``relevance_score`` is non-NULL and
       below 5.

    Steps 2 and 3 run on every call, not only right after the column is
    added. Each is skipped when the table lacks a column the statement reads
    (``is_processed`` / ``relevance_score``), because the UPDATE would
    otherwise fail and abort the whole block, including the ALTER TABLE.
    """
    inspector = inspect(engine)
    table_names = inspector.get_table_names()
    table_name = next(
        (name for name in ("paper", "papers") if name in table_names),
        None,
    )
    if table_name is None:
        return

    columns = {col["name"] for col in inspector.get_columns(table_name)}
    has_is_processed = "is_processed" in columns
    has_relevance_score = "relevance_score" in columns

    # table_name is one of two hard-coded literals, so interpolating it into
    # the statement text cannot inject arbitrary SQL.
    with engine.begin() as conn:
        if "processing_status" not in columns:
            conn.execute(
                text(f"ALTER TABLE {table_name} ADD COLUMN processing_status TEXT")
            )

        # processing_status is guaranteed to exist from here on: it was
        # either already present or added by the statement above.
        if has_is_processed:
            conn.execute(
                text(
                    f"UPDATE {table_name} "
                    "SET processing_status = CASE "
                    "WHEN is_processed THEN 'processed' ELSE 'pending' END "
                    "WHERE processing_status IS NULL"
                )
            )

        if has_is_processed and has_relevance_score:
            conn.execute(
                text(
                    f"UPDATE {table_name} "
                    "SET processing_status = 'skipped' "
                    "WHERE is_processed = TRUE "
                    "AND relevance_score IS NOT NULL "
                    "AND relevance_score < 5 "
                    "AND processing_status = 'processed'"
                )
            )


def get_session() -> Generator[Session, None, None]:
    """Yield one ``Session`` bound to the module-level engine.

    The session is opened as a context manager, so it is closed when the
    generator is resumed or closed after the single ``yield``.
    """
    with Session(engine) as db:
        yield db


def get_sync_session() -> Session:
    """Return a new ``Session`` on the module-level engine.

    Unlike ``get_session``, the session is handed back directly rather than
    through a context manager, so the caller is responsible for closing it.
    """
    session = Session(engine)
    return session