Spaces:
Sleeping
Sleeping
File size: 4,766 Bytes
import unittest
from datetime import datetime
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from starlette.requests import Request
from app.database import SessionLocal
from app.main import (
app,
_parse_optional_float,
_parse_optional_int,
block_listing_dates,
host_dashboard,
merge_ranges,
)
from app.models import Booking, Listing, TaskDefinition
from app.seed import reset_database
from app.tasks import evaluate_task
class WebArenaAirbnbTests(unittest.TestCase):
    """Tests covering the task catalog, listings, task evaluation, parsing helpers and host views.

    Several assertions compare against literal names, counts and dates; those
    values presumably come from the data installed by ``reset_database()``
    (verify against ``app.seed``).
    """

    def setUp(self):
        """Call ``reset_database()`` before every test."""
        reset_database()

    def _build_request(
        self, path: str, user_id: int | None = None, method: str = "GET"
    ) -> Request:
        """Build a Starlette ``Request`` from a hand-written HTTP scope.

        Args:
            path: URL path stored in the scope.
            user_id: When not ``None``, stored as ``session["user_id"]``;
                otherwise the session dict is left empty.
            method: HTTP method stored in the scope.

        Returns:
            A ``Request`` wrapping the scope, with ``app`` attached to it.
        """
        session = {} if user_id is None else {"user_id": user_id}
        scope = {
            "type": "http",
            "method": method,
            "path": path,
            "headers": [],
            "query_string": b"",
            # Supplied directly in the scope; presumably stands in for the
            # session middleware that would populate it in a real request.
            "session": session,
            "client": ("127.0.0.1", 12345),
            "server": ("testserver", 80),
            "scheme": "http",
            "http_version": "1.1",
            "app": app,
        }
        return Request(scope)

    def test_seeded_task_catalog_exists(self):
        """The catalog holds 10 task definitions and the first persona is avery@example.com."""
        with SessionLocal() as db:
            stmt = (
                select(TaskDefinition)
                .options(joinedload(TaskDefinition.persona))
                .order_by(TaskDefinition.id)
            )
            tasks = db.scalars(stmt).all()
            self.assertEqual(len(tasks), 10)
            self.assertEqual(tasks[0].persona.email, "avery@example.com")

    def test_catalog_contains_thirty_plus_stays(self):
        """At least 30 ``Listing`` rows are present."""
        with SessionLocal() as db:
            listings = db.scalars(select(Listing)).all()
            self.assertGreaterEqual(len(listings), 30)

    def test_booking_task_evaluator_detects_completed_goal(self):
        """After a confirmed booking is stored, ``evaluate_task`` reports success for task 1."""
        # Function-scope import: the module itself only imports ``datetime``
        # from the datetime module.
        from datetime import date, timezone

        with SessionLocal() as db:
            booking = Booking(
                confirmation_code="BKG-2999",
                listing_id=1,
                guest_id=1,
                check_in=date.fromisoformat("2026-04-14"),
                check_out=date.fromisoformat("2026-04-17"),
                guests=2,
                total_price=815,
                status="confirmed",
                # datetime.utcnow() is deprecated since Python 3.12; this
                # produces the same naive UTC timestamp.
                created_at=datetime.now(timezone.utc).replace(tzinfo=None),
            )
            db.add(booking)
            db.commit()

            task = db.scalar(
                select(TaskDefinition)
                .where(TaskDefinition.id == 1)
                .options(joinedload(TaskDefinition.persona))
            )
            result = evaluate_task(db, task)
            self.assertTrue(result["success"])

    def test_merge_ranges_combines_overlapping_blocks(self):
        """Two overlapping May ranges are merged into one; the June range is kept as is."""
        ranges = [
            {"start": "2026-05-08", "end": "2026-05-10"},
            {"start": "2026-05-09", "end": "2026-05-12"},
            {"start": "2026-06-01", "end": "2026-06-03"},
        ]
        expected = [
            {"start": "2026-05-08", "end": "2026-05-12"},
            {"start": "2026-06-01", "end": "2026-06-03"},
        ]
        self.assertEqual(merge_ranges(ranges), expected)

    def test_optional_search_filters_ignore_empty_strings(self):
        """Empty or blank input parses to ``None``; numeric text parses to a number."""
        self.assertIsNone(_parse_optional_int(""))
        self.assertIsNone(_parse_optional_int(" "))
        self.assertIsNone(_parse_optional_float(""))
        self.assertEqual(_parse_optional_int("4"), 4)
        self.assertEqual(_parse_optional_float("4.5"), 4.5)

    def test_host_dashboard_renders_host_management_sections(self):
        """The dashboard rendered for user 2 contains the host sections and a listing name."""
        with SessionLocal() as db:
            # user_id=2 is presumably a seeded host account (verify against app.seed).
            request = self._build_request("/host", user_id=2)
            response = host_dashboard(request, db)
            body = response.body.decode()
            for fragment in (
                "Manage listings and reservations",
                "Existing blocks",
                "Upcoming reservations",
                "Annex Glass Loft",
            ):
                self.assertIn(fragment, body)

    def test_block_listing_dates_updates_host_calendar_and_redirects(self):
        """Blocking dates on listing 1 redirects with a notice and updates its blocked ranges."""
        with SessionLocal() as db:
            request = self._build_request(
                "/host/listings/1/block", user_id=2, method="POST"
            )
            response = block_listing_dates(
                request=request,
                listing_id=1,
                start_date="2026-04-04",
                end_date="2026-04-07",
                db=db,
            )
            listing = db.get(Listing, 1)

            self.assertEqual(response.status_code, 303)
            self.assertEqual(
                response.headers["location"],
                "/host?notice=Blocked%20dates%20on%20Annex%20Glass%20Loft.",
            )
            # The expected first range starts on 2026-04-03, a day before the
            # requested start; presumably the new block is merged with an
            # existing overlapping range (verify against app.seed).
            self.assertEqual(
                listing.blocked_ranges,
                [
                    {"start": "2026-04-03", "end": "2026-04-07"},
                    {"start": "2026-06-12", "end": "2026-06-15"},
                ],
            )
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()
|