Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| from datetime import date, datetime, timedelta | |
| from pathlib import Path | |
| from typing import Iterable | |
| from sqlalchemy import select | |
| from sqlalchemy.orm import Session | |
| from app.config import settings | |
| from app.database import Base, SessionLocal, engine | |
| from app.generated_catalog import extend_seed_data | |
| from app.models import ( | |
| Availability, | |
| Booking, | |
| Listing, | |
| ListingImage, | |
| Message, | |
| MessageThread, | |
| Review, | |
| TaskDefinition, | |
| User, | |
| WishlistItem, | |
| ) | |
# strptime patterns for the date-only and timestamp strings in the seed data.
DATE_FORMAT = "%Y-%m-%d"
DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S"


def parse_date(value: str) -> date:
    """Convert a ``YYYY-MM-DD`` string into a :class:`date`.

    Raises:
        ValueError: if ``value`` does not match ``DATE_FORMAT``.
    """
    parsed = datetime.strptime(value, DATE_FORMAT)
    return parsed.date()


def parse_datetime(value: str) -> datetime:
    """Convert a ``YYYY-MM-DDTHH:MM:SS`` string into a naive :class:`datetime`.

    No ``%z`` directive is used, so the result carries no timezone.

    Raises:
        ValueError: if ``value`` does not match ``DATETIME_FORMAT``.
    """
    parsed = datetime.strptime(value, DATETIME_FORMAT)
    return parsed
def daterange(start_date: date, end_date: date) -> Iterable[date]:
    """Lazily yield each day from ``start_date`` up to, but excluding, ``end_date``.

    Yields nothing when ``end_date`` is not after ``start_date``.
    """
    one_day = timedelta(days=1)
    day = start_date
    while day < end_date:
        yield day
        day = day + one_day
def load_seed_data(path: str | Path | None = None) -> dict:
    """Read the seed JSON file and pass it through ``extend_seed_data``.

    Args:
        path: Location of the seed file, as a ``str`` or ``Path``. When it is
            omitted (or falsy), ``settings.seed_file`` is used instead.

    Returns:
        The decoded JSON document after ``extend_seed_data`` has processed it.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Coerce an explicit path so callers may pass plain strings. The falsy
    # check mirrors the original ``path or ...`` fallback behaviour.
    seed_path = Path(path) if path else settings.seed_file
    with seed_path.open("r", encoding="utf-8") as file:
        return extend_seed_data(json.load(file))
def _generate_availability(
    listing: Listing,
    availability_start: date,
    availability_end: date,
    bookings: list[Booking],
) -> list[Availability]:
    """Build one ``Availability`` row per day in ``[availability_start, availability_end)``.

    A day is marked unavailable when either of these holds:

    * it falls inside one of ``listing.blocked_ranges``, whose ``start`` and
      ``end`` are ``YYYY-MM-DD`` strings with an exclusive end;
    * it falls inside a booking whose status is ``"confirmed"``, counting from
      ``check_in`` up to but not including ``check_out``.

    Bookings with any other status do not block days. The rows are returned,
    not added to a session.
    """
    unavailable: set[date] = set()

    for window in listing.blocked_ranges:
        unavailable.update(
            daterange(parse_date(window["start"]), parse_date(window["end"]))
        )

    confirmed = (entry for entry in bookings if entry.status == "confirmed")
    for booking in confirmed:
        unavailable.update(daterange(booking.check_in, booking.check_out))

    return [
        Availability(
            listing_id=listing.id,
            date=day,
            is_available=day not in unavailable,
        )
        for day in daterange(availability_start, availability_end)
    ]
def _without_key(payload: dict, key: str) -> dict:
    """Return a shallow copy of ``payload`` minus ``key``; ``payload`` is not modified."""
    return {name: value for name, value in payload.items() if name != key}


def seed_database(session: Session, seed_data: dict) -> None:
    """Stage every seed entity in ``session`` and flush. The caller commits.

    ``seed_data`` must provide the keys ``users``, ``listings``, ``reviews``,
    ``bookings``, ``wishlists``, ``message_threads``, ``tasks`` and
    ``availability_window`` (``start``/``end`` as ``YYYY-MM-DD`` strings).
    Optional nested ``images`` (per listing) and ``messages`` (per thread)
    lists are read but never removed. ``seed_data`` is left unmodified, so
    the same dict can be applied more than once.

    After all entities are added, one ``Availability`` row is generated per
    listing per day of the availability window.
    """
    for user_data in seed_data["users"]:
        session.add(User(**user_data))
    session.flush()

    listings: list[Listing] = []
    for listing_data in seed_data["listings"]:
        listing = Listing(**_without_key(listing_data, "images"))
        session.add(listing)
        # Flush so listing.id is populated before the image rows reference it.
        session.flush()
        listings.append(listing)
        for image_payload in listing_data.get("images") or []:
            session.add(ListingImage(listing_id=listing.id, **image_payload))

    for review_data in seed_data["reviews"]:
        session.add(
            Review(
                **{
                    **review_data,
                    "created_at": parse_datetime(review_data["created_at"]),
                }
            )
        )
    session.flush()

    bookings: list[Booking] = []
    for booking_data in seed_data["bookings"]:
        booking = Booking(
            **{
                **booking_data,
                "check_in": parse_date(booking_data["check_in"]),
                "check_out": parse_date(booking_data["check_out"]),
                "created_at": parse_datetime(booking_data["created_at"]),
            }
        )
        session.add(booking)
        bookings.append(booking)

    for wishlist_data in seed_data["wishlists"]:
        session.add(WishlistItem(**wishlist_data))

    for thread_data in seed_data["message_threads"]:
        thread = MessageThread(
            **{
                **_without_key(thread_data, "messages"),
                "last_message_at": parse_datetime(thread_data["last_message_at"]),
            }
        )
        session.add(thread)
        # Flush so thread.id is populated before the message rows reference it.
        session.flush()
        for message_data in thread_data.get("messages") or []:
            session.add(
                Message(
                    thread_id=thread.id,
                    sender_id=message_data["sender_id"],
                    body=message_data["body"],
                    created_at=parse_datetime(message_data["created_at"]),
                )
            )

    for task_data in seed_data["tasks"]:
        session.add(TaskDefinition(**task_data))

    window = seed_data["availability_window"]
    availability_start = parse_date(window["start"])
    availability_end = parse_date(window["end"])
    session.flush()

    bookings_by_listing: dict[int, list[Booking]] = {}
    for booking in bookings:
        bookings_by_listing.setdefault(booking.listing_id, []).append(booking)

    for listing in listings:
        session.add_all(
            _generate_availability(
                listing=listing,
                availability_start=availability_start,
                availability_end=availability_end,
                bookings=bookings_by_listing.get(listing.id, []),
            )
        )
def reset_database() -> None:
    """Drop and recreate every table, then repopulate it from the seed file.

    The seed file is loaded *before* any table is dropped. If it is missing
    or is not valid JSON, the error is raised while the existing database is
    still intact, instead of after it has been wiped.
    """
    # NOTE(review): assumes extend_seed_data (called by load_seed_data) does
    # not read from the database — confirm.
    seed_data = load_seed_data()
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)
    with SessionLocal() as session:
        seed_database(session, seed_data)
        session.commit()
def ensure_database_seeded() -> None:
    """Create any missing tables, then seed them only if no ``User`` row exists.

    Unlike ``reset_database``, this never drops existing data.
    """
    Base.metadata.create_all(bind=engine)
    with SessionLocal() as session:
        first_user_id = session.scalar(select(User.id).limit(1))
        if first_user_id is not None:
            # Already populated; leave the existing data alone.
            return
        seed_database(session, load_seed_data())
        session.commit()