Spaces:
Running
Running
| import asyncio | |
| from types import SimpleNamespace | |
| import pytest | |
| from fastapi import HTTPException | |
| from apps.api.routers.bi import get_market_share, get_supply_chain_flow | |
| from apps.api.routers.export import ExportRequest, request_async_export | |
| from apps.api.routers.subscription import ( | |
| SubscriptionRequest, | |
| create_subscription, | |
| list_subscriptions, | |
| ) | |
| from apps.api.routers.trade import search_trade_records | |
| from apps.api.schemas.trade import TradeQueryRequest | |
| class _QueryResult: | |
| def __init__(self, total=None, records=None): | |
| self._total = total | |
| self._records = records or [] | |
| def scalar(self): | |
| return self._total | |
| def scalars(self): | |
| return self | |
| def all(self): | |
| return self._records | |
| class _FakeTradeDB: | |
| def __init__(self, records): | |
| self.records = records | |
| self.statements = [] | |
| async def execute(self, stmt): | |
| self.statements.append(stmt) | |
| if len(self.statements) == 1: | |
| return _QueryResult(total=len(self.records)) | |
| return _QueryResult(records=self.records) | |
| class _FakeSubscriptionDB: | |
| def __init__(self, rows=None): | |
| self.rows = rows or [] | |
| self.added = [] | |
| self.committed = False | |
| self.executed = [] | |
| def add(self, obj): | |
| self.added.append(obj) | |
| async def commit(self): | |
| self.committed = True | |
| async def execute(self, stmt): | |
| self.executed.append(stmt) | |
| return _QueryResult(records=self.rows) | |
| class _DelayTask: | |
| def __init__(self, task_id="task-123"): | |
| self.task_id = task_id | |
| self.calls = [] | |
| def delay(self, *args, **kwargs): | |
| self.calls.append((args, kwargs)) | |
| return SimpleNamespace(id=self.task_id) | |
def test_trade_search_applies_filters_and_pagination():
    """Trade search returns count and rows, and compiles filters and paging into SQL."""
    record_fields = dict(
        record_id="r1",
        source_record_id="raw1",
        source_country="BR",
        trade_direction="import",
        trade_date="2025-01-01",
        importer_name="Acme",
        exporter_name=None,
        hs_code="851712",
        product_name="phones",
        amount=100.0,
        currency="USD",
        weight=10.0,
        weight_unit="KG",
        origin_country="CN",
        destination_country="BR",
        departure_port=None,
        arrival_port=None,
        transport_mode="SEA",
    )
    db = _FakeTradeDB([SimpleNamespace(**record_fields)])
    query = TradeQueryRequest(
        source_country="BR",
        trade_direction="import",
        hs_code="8517",
        page=2,
        limit=5,
    )

    response = asyncio.run(search_trade_records(query=query, db=db))

    assert response.total == 1
    assert len(response.items) == 1
    assert response.items[0].record_id == "r1"

    # The fake answers the first execute() with a count, so the second
    # recorded statement is the row query carrying filters and paging.
    row_stmt = db.statements[1]
    sql = str(row_stmt.compile(compile_kwargs={"literal_binds": True}))
    # "8517%" is presumably a prefix LIKE on hs_code; page=2 with limit=5
    # should skip exactly one page of 5 rows.
    for fragment in ("BR", "import", "8517%", "LIMIT 5", "OFFSET 5"):
        assert fragment in sql, fragment
def test_subscription_create_and_list():
    """Creating a subscription persists and commits it; listing returns stored rows."""
    email = "a@example.com"
    key = "test_standard_key_456"
    stored_row = SimpleNamespace(
        id="s1",
        user_email=email,
        target_entity_id="e1",
        target_hs_code=None,
    )
    db = _FakeSubscriptionDB(rows=[stored_row])

    request = SubscriptionRequest(user_email=email, target_entity_id="e1")
    created = asyncio.run(create_subscription(req=request, api_key=key, db=db))

    assert created.id
    assert created.user_email == email
    assert db.committed is True
    assert db.added[0].target_entity_id == "e1"

    listed = asyncio.run(list_subscriptions(user_email=email, api_key=key, db=db))

    assert len(listed) == 1
    assert listed[0].user_email == email
def test_subscription_requires_target():
    """A subscription request with no target (only an email) is rejected with 400."""
    fake_db = _FakeSubscriptionDB()
    request = SubscriptionRequest(user_email="a@example.com")

    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(
            create_subscription(
                req=request,
                api_key="test_standard_key_456",
                db=fake_db,
            )
        )

    # Checked after the context manager exits; inside it this line would
    # never run because the call above raises.
    assert exc_info.value.status_code == 400
def test_export_submission_rejects_trial_and_accepts_paid(monkeypatch):
    """Trial-tier exports are refused (403) without queuing work; paid tiers queue once.

    The task's ``delay`` is patched *before* the trial request. That lets the
    test also verify that a rejected export never reaches the task queue.
    """
    export_fields = dict(
        hs_code="851712",
        country="BR",
        start_date="2025-01-01",
        end_date="2025-01-31",
        email="a@example.com",
    )
    delay_task = _DelayTask()
    monkeypatch.setattr("apps.api.routers.export.export_data_task.delay", delay_task.delay)

    trial_req = ExportRequest(**export_fields)
    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(request_async_export(req=trial_req, api_key="test_trial_key_123", tier="trial"))
    assert exc_info.value.status_code == 403
    # A rejected export must not have dispatched a background task.
    assert delay_task.calls == []

    paid_req = ExportRequest(**export_fields)
    result = asyncio.run(request_async_export(req=paid_req, api_key="test_standard_key_456", tier="standard"))
    assert result["task_id"] == "task-123"
    assert result["tier"] == "standard"
    # A single accepted request should enqueue exactly one task.
    assert len(delay_task.calls) == 1
def test_market_share_returns_two_periods():
    """Market share for an HS code/country yields two periods; the first is 2026-04."""
    periods = asyncio.run(
        get_market_share(
            hs_code="851712",
            target_country="US",
            api_key="test_standard_key_456",
        )
    )

    assert len(periods) == 2
    first_period = periods[0]
    assert first_period.period == "2026-04"
def test_supply_chain_flow_rejects_malicious_hs_code():
    """An HS code carrying a SQL-injection payload is rejected with HTTP 400."""
    injection_payload = "851712' OR 1=1 --"

    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(
            get_supply_chain_flow(
                hs_code=injection_payload,
                year=2026,
                api_key="test_standard_key_456",
            )
        )

    assert exc_info.value.status_code == 400
def test_supply_chain_flow_uses_clickhouse_parameters(monkeypatch):
    """The flow query binds user input as ClickHouse parameters, not SQL text.

    A fake client records every ``query`` call. The test checks that:
    - the HS code and year arrive via ``parameters``;
    - the raw HS code is never interpolated into any issued query string;
    - the returned rows are turned into graph edges and nodes.
    """

    class FakeResult:
        result_rows = [("CHN", "BRA", 120.5)]

    class FakeClient:
        def __init__(self):
            self.calls = []

        def query(self, query, parameters=None):
            self.calls.append((query, parameters))
            return FakeResult()

    client = FakeClient()
    monkeypatch.setattr("apps.api.routers.bi.get_clickhouse_client", lambda: client)

    result = asyncio.run(
        get_supply_chain_flow(
            hs_code="851712",
            year=2026,
            api_key="test_standard_key_456",
        )
    )

    assert client.calls[0][1] == {"hs_code": "851712", "year": 2026}
    # True parameterization means the literal value is never spliced into SQL.
    for sql_text, _params in client.calls:
        assert "851712" not in str(sql_text)
    assert result.edges[0].source == "CHN"
    assert result.edges[0].target == "BRA"
    assert {node.name for node in result.nodes} == {"CHN", "BRA"}