TravelMap / app /seed.py
Jack
Initial commit
5ff3858
from __future__ import annotations
import json
from datetime import date, datetime, timedelta
from pathlib import Path
from typing import Iterable
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.config import settings
from app.database import Base, SessionLocal, engine
from app.generated_catalog import extend_seed_data
from app.models import (
Availability,
Booking,
Listing,
ListingImage,
Message,
MessageThread,
Review,
TaskDefinition,
User,
WishlistItem,
)
# strptime pattern for calendar dates in the seed payload, e.g. "2024-05-01".
DATE_FORMAT = "%Y-%m-%d"
# strptime pattern for timestamps in the seed payload, e.g. "2024-05-01T14:30:00".
# It has no %z and no %f, so parsed values are naive datetimes without fractional seconds.
DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S"
def parse_date(value: str) -> date:
    """Convert a ``DATE_FORMAT`` string such as ``"2024-05-01"`` into a ``date``.

    Raises:
        ValueError: if ``value`` does not match ``DATE_FORMAT``.
    """
    parsed = datetime.strptime(value, DATE_FORMAT)
    return parsed.date()
def parse_datetime(value: str) -> datetime:
    """Convert a ``DATETIME_FORMAT`` string into a naive ``datetime``.

    Raises:
        ValueError: if ``value`` does not match ``DATETIME_FORMAT``.
    """
    timestamp = datetime.strptime(value, DATETIME_FORMAT)
    return timestamp
def daterange(start_date: date, end_date: date) -> Iterable[date]:
    """Lazily yield each day from ``start_date`` up to, but excluding, ``end_date``.

    Yields nothing when ``end_date`` is on or before ``start_date``.
    """
    step = timedelta(days=1)
    day = start_date
    while day < end_date:
        yield day
        day = day + step
def load_seed_data(path: Path | None = None) -> dict:
    """Read the JSON seed file and pass it through ``extend_seed_data``.

    Args:
        path: Seed file to read. Falls back to ``settings.seed_file`` when
            ``path`` is falsy (e.g. omitted).

    Returns:
        The dict returned by ``extend_seed_data`` for the decoded JSON.
    """
    source = path or settings.seed_file
    with source.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return extend_seed_data(payload)
def _generate_availability(
    listing: Listing,
    availability_start: date,
    availability_end: date,
    bookings: list[Booking],
) -> list[Availability]:
    """Build one ``Availability`` row per day in ``[availability_start, availability_end)``.

    A day is marked unavailable when either of these holds:
    - it falls inside one of ``listing.blocked_ranges``, whose "end" is
      exclusive because of ``daterange``;
    - it falls inside a booking whose status is ``"confirmed"``, from
      ``check_in`` up to but excluding ``check_out``.

    Bookings with any other status do not block dates.
    """
    unavailable: set[date] = set()

    for window in listing.blocked_ranges:
        unavailable.update(
            daterange(parse_date(window["start"]), parse_date(window["end"]))
        )

    confirmed = (booking for booking in bookings if booking.status == "confirmed")
    for booking in confirmed:
        unavailable.update(daterange(booking.check_in, booking.check_out))

    return [
        Availability(
            listing_id=listing.id,
            date=day,
            is_available=day not in unavailable,
        )
        for day in daterange(availability_start, availability_end)
    ]
def _seed_listings(session: Session, listings_data: list[dict]) -> list[Listing]:
    """Add each listing and its nested images; return the listings in seed order."""
    listings: list[Listing] = []
    for listing_data in listings_data:
        # Copy so popping "images" does not strip them from the caller's payload.
        fields = dict(listing_data)
        image_payloads = fields.pop("images", [])
        listing = Listing(**fields)
        session.add(listing)
        # Flush so listing.id is populated before the images reference it.
        session.flush()
        listings.append(listing)
        for image_payload in image_payloads:
            session.add(ListingImage(listing_id=listing.id, **image_payload))
    return listings


def _seed_message_threads(session: Session, threads_data: list[dict]) -> None:
    """Add each message thread and its nested messages without mutating the input."""
    for thread_data in threads_data:
        # Copy so popping "messages" does not strip them from the caller's payload.
        fields = dict(thread_data)
        messages = fields.pop("messages", [])
        thread = MessageThread(
            **{
                **fields,
                "last_message_at": parse_datetime(fields["last_message_at"]),
            }
        )
        session.add(thread)
        # Flush so thread.id is populated before the messages reference it.
        session.flush()
        for message_data in messages:
            session.add(
                Message(
                    thread_id=thread.id,
                    sender_id=message_data["sender_id"],
                    body=message_data["body"],
                    created_at=parse_datetime(message_data["created_at"]),
                )
            )


def _seed_availability(
    session: Session,
    *,
    listings: list[Listing],
    bookings: list[Booking],
    availability_start: date,
    availability_end: date,
) -> None:
    """Add per-day Availability rows for every listing over the given window."""
    # Write all pending rows before adding the availability batch.
    session.flush()
    bookings_by_listing: dict[int, list[Booking]] = {}
    for booking in bookings:
        bookings_by_listing.setdefault(booking.listing_id, []).append(booking)
    for listing in listings:
        session.add_all(
            _generate_availability(
                listing=listing,
                availability_start=availability_start,
                availability_end=availability_end,
                bookings=bookings_by_listing.get(listing.id, []),
            )
        )


def seed_database(session: Session, seed_data: dict) -> None:
    """Add every entity described by ``seed_data`` to ``session``.

    Entities are added in this order:
    - users;
    - listings with their images;
    - reviews;
    - bookings;
    - wishlist items;
    - message threads with their messages;
    - task definitions;
    - one availability row per listing per day of
      ``seed_data["availability_window"]``, with the end date exclusive.

    The session is flushed along the way so generated ids are available. It is
    never committed; the caller owns the transaction.

    ``seed_data`` is treated as read-only. The nested ``images`` and
    ``messages`` lists are read from shallow copies, so the same payload can be
    seeded more than once without losing them.

    Args:
        session: Open SQLAlchemy session that receives the new rows.
        seed_data: Parsed seed payload, e.g. the result of ``load_seed_data()``.
    """
    for user_data in seed_data["users"]:
        session.add(User(**user_data))
    session.flush()

    listings = _seed_listings(session, seed_data["listings"])

    for review_data in seed_data["reviews"]:
        session.add(
            Review(
                **{
                    **review_data,
                    "created_at": parse_datetime(review_data["created_at"]),
                }
            )
        )
    session.flush()

    bookings: list[Booking] = []
    for booking_data in seed_data["bookings"]:
        booking = Booking(
            **{
                **booking_data,
                "check_in": parse_date(booking_data["check_in"]),
                "check_out": parse_date(booking_data["check_out"]),
                "created_at": parse_datetime(booking_data["created_at"]),
            }
        )
        session.add(booking)
        bookings.append(booking)

    for wishlist_data in seed_data["wishlists"]:
        session.add(WishlistItem(**wishlist_data))

    _seed_message_threads(session, seed_data["message_threads"])

    for task_data in seed_data["tasks"]:
        session.add(TaskDefinition(**task_data))

    window = seed_data["availability_window"]
    _seed_availability(
        session,
        listings=listings,
        bookings=bookings,
        availability_start=parse_date(window["start"]),
        availability_end=parse_date(window["end"]),
    )
def reset_database() -> None:
    """Drop and recreate every table, then repopulate it from the seed file.

    The seed file is loaded *before* any table is dropped. A missing or
    malformed seed file therefore raises while the existing data is still
    intact, instead of leaving an empty, freshly-recreated schema behind.
    """
    seed_data = load_seed_data()
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)
    with SessionLocal() as session:
        seed_database(session, seed_data)
        session.commit()
def ensure_database_seeded() -> None:
    """Create any missing tables, and seed them only if the users table is empty.

    An existing database with at least one user is left untouched.
    """
    Base.metadata.create_all(bind=engine)
    with SessionLocal() as session:
        first_user_id = session.scalar(select(User.id).limit(1))
        if first_user_id is not None:
            return
        seed_database(session, load_seed_data())
        session.commit()