customs-data / tests /baseline /test_api_contracts.py
3v324v23's picture
Enhances platform with robust monitoring and notifications
4e22b4d
Raw
History Blame Contribute Delete
6.59 kB
import asyncio
from types import SimpleNamespace
import pytest
from fastapi import HTTPException
from apps.api.routers.bi import get_market_share, get_supply_chain_flow
from apps.api.routers.export import ExportRequest, request_async_export
from apps.api.routers.subscription import (
SubscriptionRequest,
create_subscription,
list_subscriptions,
)
from apps.api.routers.trade import search_trade_records
from apps.api.schemas.trade import TradeQueryRequest
class _QueryResult:
def __init__(self, total=None, records=None):
self._total = total
self._records = records or []
def scalar(self):
return self._total
def scalars(self):
return self
def all(self):
return self._records
class _FakeTradeDB:
def __init__(self, records):
self.records = records
self.statements = []
async def execute(self, stmt):
self.statements.append(stmt)
if len(self.statements) == 1:
return _QueryResult(total=len(self.records))
return _QueryResult(records=self.records)
class _FakeSubscriptionDB:
def __init__(self, rows=None):
self.rows = rows or []
self.added = []
self.committed = False
self.executed = []
def add(self, obj):
self.added.append(obj)
async def commit(self):
self.committed = True
async def execute(self, stmt):
self.executed.append(stmt)
return _QueryResult(records=self.rows)
class _DelayTask:
def __init__(self, task_id="task-123"):
self.task_id = task_id
self.calls = []
def delay(self, *args, **kwargs):
self.calls.append((args, kwargs))
return SimpleNamespace(id=self.task_id)
def test_trade_search_applies_filters_and_pagination():
    """Trade search must push filters and page/limit into the generated SQL.

    ``_FakeTradeDB`` answers the first ``execute`` with a count and the
    second with rows, so ``db.statements[1]`` is the row query. It is
    compiled with literal binds so filter values appear inline in the text.
    """
    records = [
        SimpleNamespace(
            record_id="r1",
            source_record_id="raw1",
            source_country="BR",
            trade_direction="import",
            trade_date="2025-01-01",
            importer_name="Acme",
            exporter_name=None,
            hs_code="851712",
            product_name="phones",
            amount=100.0,
            currency="USD",
            weight=10.0,
            weight_unit="KG",
            origin_country="CN",
            destination_country="BR",
            departure_port=None,
            arrival_port=None,
            transport_mode="SEA",
        )
    ]
    db = _FakeTradeDB(records)
    query = TradeQueryRequest(source_country="BR", trade_direction="import", hs_code="8517", page=2, limit=5)

    result = asyncio.run(search_trade_records(query=query, db=db))

    assert result.total == 1
    assert len(result.items) == 1
    assert result.items[0].record_id == "r1"

    compiled = str(db.statements[1].compile(compile_kwargs={"literal_binds": True}))
    assert "BR" in compiled
    # Match the quoted literal, not the bare word: the selected columns likely
    # include ``importer_name``, so a bare "import" substring would be found
    # even if the trade_direction filter were missing. Lower-casing tolerates
    # an enum type that renders the member name instead of its value.
    assert "'import'" in compiled.lower()
    # hs_code is a prefix filter (LIKE '8517%').
    assert "8517%" in compiled
    # page=2 with limit=5 skips exactly one page of 5 rows.
    assert "LIMIT 5" in compiled
    assert "OFFSET 5" in compiled
def test_subscription_create_and_list():
    """Creating a subscription adds and commits it; listing returns stored rows."""
    api_key = "test_standard_key_456"
    email = "a@example.com"
    stored_row = SimpleNamespace(id="s1", user_email=email, target_entity_id="e1", target_hs_code=None)
    db = _FakeSubscriptionDB(rows=[stored_row])

    create_req = SubscriptionRequest(user_email=email, target_entity_id="e1")
    created = asyncio.run(create_subscription(req=create_req, api_key=api_key, db=db))

    assert created.id
    assert created.user_email == email
    assert db.committed is True
    assert db.added[0].target_entity_id == "e1"

    listed = asyncio.run(list_subscriptions(user_email=email, api_key=api_key, db=db))
    assert len(listed) == 1
    assert listed[0].user_email == email
def test_subscription_requires_target():
    """A subscription request that names no target is rejected with HTTP 400."""
    fake_db = _FakeSubscriptionDB()
    with pytest.raises(HTTPException) as exc_info:
        coro = create_subscription(
            req=SubscriptionRequest(user_email="a@example.com"),
            api_key="test_standard_key_456",
            db=fake_db,
        )
        asyncio.run(coro)
    assert exc_info.value.status_code == 400
def test_export_submission_rejects_trial_and_accepts_paid(monkeypatch):
    """Trial tier is refused with 403 and enqueues nothing; paid tier enqueues one export."""
    # Patch ``delay`` before either request so a regression in the tier check
    # can never reach the real task queue. It also lets us assert that the
    # rejected request enqueued nothing.
    delay_task = _DelayTask()
    monkeypatch.setattr("apps.api.routers.export.export_data_task.delay", delay_task.delay)

    def _make_request():
        # Both tiers submit the same export parameters.
        return ExportRequest(
            hs_code="851712",
            country="BR",
            start_date="2025-01-01",
            end_date="2025-01-31",
            email="a@example.com",
        )

    trial_req = _make_request()
    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(request_async_export(req=trial_req, api_key="test_trial_key_123", tier="trial"))
    assert exc_info.value.status_code == 403
    assert delay_task.calls == []

    paid_req = _make_request()
    result = asyncio.run(request_async_export(req=paid_req, api_key="test_standard_key_456", tier="standard"))
    assert result["task_id"] == "task-123"
    assert result["tier"] == "standard"
    assert len(delay_task.calls) == 1
def test_market_share_returns_two_periods():
    """Market share returns exactly two periods, the first labelled "2026-04"."""
    periods = asyncio.run(
        get_market_share(
            hs_code="851712",
            target_country="US",
            api_key="test_standard_key_456",
        )
    )
    assert len(periods) == 2
    first_period = periods[0]
    assert first_period.period == "2026-04"
def test_supply_chain_flow_rejects_malicious_hs_code(monkeypatch):
    """An injection-shaped HS code is rejected with 400 before any query runs.

    ``get_supply_chain_flow`` comes from the module-level import; the
    ClickHouse client is stubbed so a validation regression cannot reach a
    real database.
    """
    issued_queries = []

    class _RecordingClient:
        # Stand-in client: records any query and fails loudly, because a
        # rejected request must never reach the database.
        def query(self, query, parameters=None):
            issued_queries.append((query, parameters))
            raise AssertionError("query must not run for an invalid hs_code")

    monkeypatch.setattr("apps.api.routers.bi.get_clickhouse_client", lambda: _RecordingClient())

    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(
            get_supply_chain_flow(
                hs_code="851712' OR 1=1 --",
                year=2026,
                api_key="test_standard_key_456",
            )
        )
    assert exc_info.value.status_code == 400
    assert issued_queries == []
def test_supply_chain_flow_uses_clickhouse_parameters(monkeypatch):
    """HS code and year travel as bound query parameters, not spliced into SQL.

    Also checks that returned rows are mapped onto graph edges and nodes.
    ``get_supply_chain_flow`` comes from the module-level import.
    """

    class FakeResult:
        # One (source, target, value) row.
        result_rows = [("CHN", "BRA", 120.5)]

    class FakeClient:
        def __init__(self):
            self.calls = []

        def query(self, query, parameters=None):
            self.calls.append((query, parameters))
            return FakeResult()

    client = FakeClient()
    monkeypatch.setattr("apps.api.routers.bi.get_clickhouse_client", lambda: client)

    result = asyncio.run(
        get_supply_chain_flow(
            hs_code="851712",
            year=2026,
            api_key="test_standard_key_456",
        )
    )

    query_text, parameters = client.calls[0]
    assert parameters == {"hs_code": "851712", "year": 2026}
    # The raw value must only travel via ``parameters``; seeing it in the SQL
    # text would mean it was interpolated into the query string.
    assert "851712" not in str(query_text)
    assert result.edges[0].source == "CHN"
    assert result.edges[0].target == "BRA"
    assert {node.name for node in result.nodes} == {"CHN", "BRA"}