File size: 3,632 Bytes
5d924ac
 
 
 
ac35357
5d924ac
18cd44b
41341e4
 
 
 
5d924ac
 
 
 
 
 
 
 
 
 
 
 
 
cbf0d7c
74e7eda
5d924ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac35357
41341e4
ac35357
 
 
5d924ac
c9750cf
 
 
 
 
 
 
 
 
c000986
 
c9750cf
 
 
 
 
 
 
edd46d3
18cd44b
c9750cf
5d924ac
 
41341e4
 
 
 
5d924ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41341e4
5d924ac
41341e4
5d924ac
 
 
 
41341e4
 
 
 
 
 
 
 
 
5d924ac
 
6cb2a59
5d924ac
6cb2a59
5d924ac
 
 
 
 
 
 
 
6cb2a59
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from sqlalchemy import create_engine, UnicodeText, DateTime, ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session, joinedload, relationship
from sqlalchemy.sql import func
from contextlib import contextmanager
import os
import datetime
import logging
from cachetools import cached, TTLCache


CACHE_TTL_SECONDS = 10 * 60  # ten minutes


class Base(DeclarativeBase):
    '''

    Declarative base class shared by every ORM model defined in this module.

    '''


class Article(Base):
    '''

    Represents an article row in the ``article`` table, together with the
    Source rows that reference it.

    '''
    __tablename__ = 'article'

    # Integer primary key.
    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str] = mapped_column(UnicodeText)
    content: Mapped[str] = mapped_column(UnicodeText)
    # Filled in by the database with its current time (func.now()) when no value
    # is supplied on insert.
    # NOTE(review): SQLite has no native timezone-aware type; confirm values read
    # back are aware datetimes as timezone=True suggests.
    date: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
    # Database-side default is the string 'news'.
    category: Mapped[str] = mapped_column(UnicodeText, server_default='news')
    # Optional; the column is explicitly nullable even though the annotation is
    # plain ``str``. NOTE(review): consider annotating as Optional[str].
    image_url: Mapped[str] = mapped_column(UnicodeText, nullable=True, server_default=None)

    # One-to-many link to Source rows; the backref adds ``Source.article``.
    sources = relationship('Source', backref='article')

    def __repr__(self) -> str:
        # Includes the full article content, which may be long.
        return f'Article(id={self.id}, title={self.title}, content={self.content})'
    

class Source(Base):
    '''

    Represents a source URL for an article.

    '''
    __tablename__ = 'source'

    # Integer primary key.
    id: Mapped[int] = mapped_column(primary_key=True)
    url: Mapped[str] = mapped_column(UnicodeText)
    # Foreign key to the owning row in the ``article`` table.
    article_id: Mapped[int] = mapped_column(ForeignKey('article.id'))


# Read database configuration from the environment. Turso is opt-in: set
# USE_TURSO=true (case-insensitive) and provide both the URL and the auth token.
USE_TURSO = os.getenv('USE_TURSO', 'false')
TURSO_DATABASE_URL = os.getenv('TURSO_DATABASE_URL')
TURSO_AUTH_TOKEN = os.getenv('TURSO_AUTH_TOKEN')

# Create the engine: try the remote Turso database first when it is enabled,
# and fall back to a local SQLite file if it is disabled, misconfigured, or
# fails to initialise.
connected_to_turso = False
if USE_TURSO.strip().lower() == 'true':
    if TURSO_DATABASE_URL and TURSO_AUTH_TOKEN:
        try:
            engine = create_engine(
                f'sqlite+{TURSO_DATABASE_URL}?secure=true',
                connect_args={
                    'auth_token': TURSO_AUTH_TOKEN
                },
                echo=False,
                pool_recycle=7000,
                pool_pre_ping=True,
            )
            # create_all() has to talk to the database, so a failure here also
            # means the remote database is unusable.
            Base.metadata.create_all(engine)
            connected_to_turso = True
        except Exception:
            # Narrower than a bare except (lets KeyboardInterrupt/SystemExit
            # through) and records the traceback instead of discarding it.
            logging.exception('Failed to connect to remote Turso database')
    else:
        # Previously this case fell back to SQLite with no explanation.
        logging.error('USE_TURSO is enabled but TURSO_DATABASE_URL or TURSO_AUTH_TOKEN is not set')

if not connected_to_turso:
    logging.warning('Using local SQLite database')
    engine = create_engine('sqlite:///news.db', echo=False)
    Base.metadata.create_all(engine)


# Single-slot TTL cache backing get_cached_articles(); that function takes no
# arguments, so it only ever stores one entry, which expires
# CACHE_TTL_SECONDS after it is stored.
article_cache = TTLCache(maxsize=1, ttl=CACHE_TTL_SECONDS)


@contextmanager
def get_session():
    '''

    Yields a new Session bound to the module-level engine; the session is
    closed when the ``with`` block exits, whether normally or via an exception.
    Nothing is committed here; callers must commit explicitly.

    '''
    # Session.__exit__ calls close(), giving the same cleanup as try/finally.
    with Session(engine) as session:
        yield session


def add_article(session: Session, article: Article):
    '''

    Stages the given article for insertion in the session; it is only written
    to the database when the caller flushes or commits.

    '''
    session.add(article)


def _retrieve_articles_from_db(session: Session):
    '''

    Loads every article from the database as a list, eagerly fetching each
    article's sources via a joined load in the same query.

    '''
    eager_sources = joinedload(Article.sources)
    query = session.query(Article).options(eager_sources)
    return query.all()


@cached(article_cache)
def get_cached_articles():
    '''

    Returns the list of all articles, served from article_cache while the
    cached entry is fresh and re-read from the database otherwise.

    '''
    # NOTE(review): every caller receives the same cached list object, and the
    # articles in it come from a session that is already closed.
    with get_session() as session:
        articles = _retrieve_articles_from_db(session=session)
    return articles


def clear_articles(session: Session):
    '''

    Issues bulk deletes for every source and every article in the session;
    the caller is responsible for committing.

    '''
    # Source rows hold a foreign key to article, so they are removed first.
    for model in (Source, Article):
        session.query(model).delete()


def add_sources(session: Session, sources: list[Source]):
    '''

    Stages all of the given sources for insertion in the session; they are
    only written to the database when the caller flushes or commits.

    '''
    # Session.add_all is the library's bulk form of calling add() per item.
    session.add_all(sources)