focal / app /main.py
michaelkri
Improved reliability of database connections
c000986
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from contextlib import asynccontextmanager
from apscheduler.schedulers.background import BackgroundScheduler
import threading
import os
import logging
import itertools
from datetime import datetime, timedelta
from .database import get_cached_articles, Article
from .update_news import update_news
DEFAULT_UPDATE_HOURS = '0,6,12,18' # UTC time
FORCED_UPDATE_HOURS = 8 # Force update if articles are older than 8 hours

# enable verbose logging when the DEBUG environment variable is exactly 'true'
if os.environ.get('DEBUG', 'false') == 'true':
    # basicConfig attaches a stderr handler to the root logger only if it has
    # none yet; with no handler anywhere, DEBUG records would only reach
    # logging's last-resort handler, which drops anything below WARNING
    logging.basicConfig(level=logging.DEBUG)
    root_logger = logging.getLogger()
    # basicConfig is a no-op when handlers already exist, so lower the level
    # explicitly as well
    root_logger.setLevel(logging.DEBUG)

# set update interval from environment variable if available
# (comma-separated hours, passed as the cron 'hour' field, e.g. '0,6,12,18')
update_schedule_hours = os.environ.get('UPDATE_HOURS', DEFAULT_UPDATE_HOURS)

# for updating the news feed periodically; the scheduler is pinned to UTC so
# the cron hours above are interpreted in UTC as documented, rather than in
# the server's local timezone (APScheduler's default)
scheduler = BackgroundScheduler(timezone='UTC')

# prevent initiating multiple concurrent updates
is_updating = threading.Lock()
def safe_update_news():
    '''
    Wrapper for update_news to ensure only one instance runs at a time.

    The lock is acquired without blocking: if another update (scheduled job,
    page request, or debug endpoint) already holds it, this call returns
    immediately instead of queueing a second update behind it.

    Returns:
        True if this call ran update_news, False if it was skipped because an
        update was already in progress.
    '''
    # acquire(blocking=False) tests and takes the lock atomically; checking
    # locked() and then entering the lock lets two threads both pass the check,
    # after which the loser runs a redundant second update
    if not is_updating.acquire(blocking=False):
        logging.getLogger(__name__).debug('News update already in progress; skipping')
        return False
    try:
        update_news()
    finally:
        # release even if update_news raises, so later updates are not blocked
        is_updating.release()
    return True
@asynccontextmanager
async def lifespan(app: FastAPI):
    '''
    FastAPI lifespan handler that runs the periodic news update while the
    application is up.

    Registers safe_update_news as a cron job (minute 0 of every hour listed in
    update_schedule_hours), starts the background scheduler, and shuts the
    scheduler down when the application exits.
    '''
    # update news periodically; replace_existing=True replaces any job already
    # registered under 'update_task' instead of raising a conflicting-ID error
    scheduler.add_job(
        safe_update_news,
        'cron',
        hour=update_schedule_hours,
        minute=0,
        id='update_task',
        replace_existing=True
    )
    scheduler.start()
    try:
        yield
    finally:
        # stop the scheduler when closing the server, including on error exits
        scheduler.shutdown()
app = FastAPI(lifespan=lifespan)

# resolve asset directories relative to this module instead of the process's
# working directory, so the server also starts when launched from outside the
# project root (identical paths when started from the root as before)
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

app.mount("/static", StaticFiles(directory=os.path.join(_APP_DIR, 'static')), name="static")
templates = Jinja2Templates(directory=os.path.join(_APP_DIR, 'templates'))
# manual update trigger, only registered when DEBUG is exactly 'true'
if os.environ.get('DEBUG', 'false') == 'true':
    @app.get('/update')
    async def update_articles(request: Request, background_tasks: BackgroundTasks):
        '''Queue a guarded news update to run after this response is sent.'''
        background_tasks.add_task(safe_update_news)
        return {'response': 'ok'}
def group_articles_by_category(articles: list[Article]) -> dict:
    '''
    Bucket articles by their ``category`` attribute.

    Keys appear in sorted category order. Within each category, articles keep
    their relative input order because the sort is stable.
    '''
    buckets = {}
    for article in sorted(articles, key=lambda item: item.category):
        buckets.setdefault(article.category, []).append(article)
    return buckets
@app.get('/')
async def read_root(request: Request, background_tasks: BackgroundTasks):
    '''
    Render the news feed grouped by category.

    Also schedules a background refresh (through safe_update_news) when there
    are no articles yet, or when the newest article is at least
    FORCED_UPDATE_HOURS old (e.g. after a missed or failed scheduled update).
    '''
    # retrieve articles from database
    articles = get_cached_articles()

    # whole hours since the newest article; -1 when unknown (no articles)
    last_updated_hours = -1

    if not articles:
        # no articles yet: fetch news after the response has been sent
        background_tasks.add_task(safe_update_news)
    else:
        # take the newest date explicitly instead of relying on the order
        # returned by get_cached_articles()
        newest = max(article.date for article in articles)
        # NOTE(review): datetime.now() is naive local time; this assumes
        # article dates are stored as naive local times as well — confirm.
        elapsed_seconds = (datetime.now() - newest).total_seconds()
        # clamp at 0 so clock skew or a timezone mismatch cannot yield a
        # negative age (which would collide with the -1 sentinel below)
        last_updated_hours = max(0, int(elapsed_seconds // 3600))

        # Force update if articles are older than the threshold (due to missed/failed update)
        if last_updated_hours >= FORCED_UPDATE_HOURS:
            background_tasks.add_task(safe_update_news)

    categorized_articles = group_articles_by_category(articles)

    return templates.TemplateResponse(
        'index.html',
        {
            'request': request,
            'categorized_articles': categorized_articles,
            # -1 is sent while an update holds the lock; presumably the
            # template treats it as "updating" — verify against index.html
            'update_time': -1 if is_updating.locked() else last_updated_hours
        }
    )