"""Tests that call the API router coroutines directly with in-memory fakes.

Each handler is driven through ``asyncio.run`` and handed lightweight
stand-ins for its database session or task queue, so nothing here talks to
a real database, broker, or ClickHouse instance.
"""

import asyncio
from types import SimpleNamespace

import pytest
from fastapi import HTTPException

from apps.api.routers.bi import get_market_share, get_supply_chain_flow
from apps.api.routers.export import ExportRequest, request_async_export
from apps.api.routers.subscription import (
    SubscriptionRequest,
    create_subscription,
    list_subscriptions,
)
from apps.api.routers.trade import search_trade_records
from apps.api.schemas.trade import TradeQueryRequest


class _QueryResult:
    """Stand-in for the object returned by a fake ``db.execute`` call.

    Supports the three accessors the tests rely on: ``scalar()`` for a
    single value, and ``scalars().all()`` for a list of rows.
    """

    def __init__(self, total=None, records=None):
        self._total = total
        self._records = records or []

    def scalar(self):
        """Return the stored total (``None`` when none was supplied)."""
        return self._total

    def scalars(self):
        """Return ``self`` so that ``.scalars().all()`` can be chained."""
        return self

    def all(self):
        """Return the stored list of records."""
        return self._records


class _FakeTradeDB:
    """Fake async session that records every statement it is asked to run.

    The first ``execute`` call answers with a total count; every later call
    answers with the record list. The trade-search test depends on that
    ordering when it inspects ``statements[1]``.
    """

    def __init__(self, records):
        self.records = records
        self.statements = []

    async def execute(self, stmt):
        """Record ``stmt`` and return a count result first, rows afterwards."""
        self.statements.append(stmt)
        if len(self.statements) == 1:
            return _QueryResult(total=len(self.records))
        return _QueryResult(records=self.records)


class _FakeSubscriptionDB:
    """Fake async session tracking added objects, commits and statements."""

    def __init__(self, rows=None):
        self.rows = rows or []
        self.added = []
        self.committed = False
        self.executed = []

    def add(self, obj):
        """Remember ``obj`` so tests can inspect what was staged."""
        self.added.append(obj)

    async def commit(self):
        """Flag that a commit was requested."""
        self.committed = True

    async def execute(self, stmt):
        """Record ``stmt`` and return the preconfigured rows."""
        self.executed.append(stmt)
        return _QueryResult(records=self.rows)


class _DelayTask:
    """Replacement for a task's ``delay`` method that captures its calls."""

    def __init__(self, task_id="task-123"):
        self.task_id = task_id
        self.calls = []

    def delay(self, *args, **kwargs):
        """Record the call and return an object exposing ``id``."""
        self.calls.append((args, kwargs))
        return SimpleNamespace(id=self.task_id)


def test_trade_search_applies_filters_and_pagination():
    """Search returns the fake rows and renders filters, LIMIT and OFFSET."""
    records = [
        SimpleNamespace(
            record_id="r1",
            source_record_id="raw1",
            source_country="BR",
            trade_direction="import",
            trade_date="2025-01-01",
            importer_name="Acme",
            exporter_name=None,
            hs_code="851712",
            product_name="phones",
            amount=100.0,
            currency="USD",
            weight=10.0,
            weight_unit="KG",
            origin_country="CN",
            destination_country="BR",
            departure_port=None,
            arrival_port=None,
            transport_mode="SEA",
        )
    ]
    db = _FakeTradeDB(records)
    query = TradeQueryRequest(
        source_country="BR",
        trade_direction="import",
        hs_code="8517",
        page=2,
        limit=5,
    )

    result = asyncio.run(search_trade_records(query=query, db=db))

    assert result.total == 1
    assert len(result.items) == 1
    assert result.items[0].record_id == "r1"

    # statements[0] is the count query; statements[1] is the row query, which
    # is the one expected to carry the filters and the pagination clauses.
    compiled = str(db.statements[1].compile(compile_kwargs={"literal_binds": True}))
    assert "BR" in compiled
    assert "import" in compiled
    assert "8517%" in compiled
    assert "LIMIT 5" in compiled
    assert "OFFSET 5" in compiled


def test_subscription_create_and_list():
    """Creating a subscription stages and commits it; listing returns rows."""
    db = _FakeSubscriptionDB(
        rows=[
            SimpleNamespace(
                id="s1",
                user_email="a@example.com",
                target_entity_id="e1",
                target_hs_code=None,
            )
        ]
    )

    created = asyncio.run(
        create_subscription(
            req=SubscriptionRequest(user_email="a@example.com", target_entity_id="e1"),
            api_key="test_standard_key_456",
            db=db,
        )
    )

    assert created.id
    assert created.user_email == "a@example.com"
    assert db.committed is True
    assert db.added[0].target_entity_id == "e1"

    listed = asyncio.run(
        list_subscriptions(
            user_email="a@example.com",
            api_key="test_standard_key_456",
            db=db,
        )
    )

    assert len(listed) == 1
    assert listed[0].user_email == "a@example.com"


def test_subscription_requires_target():
    """A subscription request without any target is rejected with HTTP 400."""
    db = _FakeSubscriptionDB()

    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(
            create_subscription(
                req=SubscriptionRequest(user_email="a@example.com"),
                api_key="test_standard_key_456",
                db=db,
            )
        )

    assert exc_info.value.status_code == 400


def test_export_submission_rejects_trial_and_accepts_paid(monkeypatch):
    """Trial tier gets HTTP 403; standard tier enqueues the export task."""
    trial_req = ExportRequest(
        hs_code="851712",
        country="BR",
        start_date="2025-01-01",
        end_date="2025-01-31",
        email="a@example.com",
    )

    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(
            request_async_export(req=trial_req, api_key="test_trial_key_123", tier="trial")
        )

    assert exc_info.value.status_code == 403

    # Swap the task's ``delay`` for a recorder so no real queue is contacted.
    delay_task = _DelayTask()
    monkeypatch.setattr("apps.api.routers.export.export_data_task.delay", delay_task.delay)

    paid_req = ExportRequest(
        hs_code="851712",
        country="BR",
        start_date="2025-01-01",
        end_date="2025-01-31",
        email="a@example.com",
    )

    result = asyncio.run(
        request_async_export(req=paid_req, api_key="test_standard_key_456", tier="standard")
    )

    assert result["task_id"] == "task-123"
    assert result["tier"] == "standard"
    assert delay_task.calls


def test_market_share_returns_two_periods():
    """Market share responds with two entries, the first for ``2026-04``."""
    result = asyncio.run(
        get_market_share(
            hs_code="851712",
            target_country="US",
            api_key="test_standard_key_456",
        )
    )

    assert len(result) == 2
    # NOTE(review): the expected period is hard-coded. If get_market_share
    # derives its periods from the current date this assertion will go
    # stale -- confirm against the handler.
    assert result[0].period == "2026-04"


def test_supply_chain_flow_rejects_malicious_hs_code():
    """An HS code carrying SQL-injection text is rejected with HTTP 400."""
    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(
            get_supply_chain_flow(
                hs_code="851712' OR 1=1 --",
                year=2026,
                api_key="test_standard_key_456",
            )
        )

    assert exc_info.value.status_code == 400


def test_supply_chain_flow_uses_clickhouse_parameters(monkeypatch):
    """Inputs reach the client as bound parameters and rows become edges."""

    class FakeResult:
        """Query result exposing a single (source, target, value) row."""

        result_rows = [("CHN", "BRA", 120.5)]

    class FakeClient:
        """Client double that records each ``query`` invocation."""

        def __init__(self):
            self.calls = []

        def query(self, query, parameters=None):
            self.calls.append((query, parameters))
            return FakeResult()

    client = FakeClient()
    monkeypatch.setattr("apps.api.routers.bi.get_clickhouse_client", lambda: client)

    result = asyncio.run(
        get_supply_chain_flow(
            hs_code="851712",
            year=2026,
            api_key="test_standard_key_456",
        )
    )

    # The second element of the recorded call is the ``parameters`` mapping;
    # the values must travel there rather than being spliced into the query.
    assert client.calls[0][1] == {"hs_code": "851712", "year": 2026}
    assert result.edges[0].source == "CHN"
    assert result.edges[0].target == "BRA"
    assert {node.name for node in result.nodes} == {"CHN", "BRA"}