import asyncio
import concurrent.futures
import contextlib
import csv
import html
import os
import re
import secrets
import threading
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional

from celery import shared_task

from packages.core.logger import app_logger

# Directory where export files are written (could be replaced by object storage such as S3/OSS).
EXPORT_DIR = os.getenv("EXPORT_DIR", "/tmp/customs_exports")
# Download URL prefix; export_data_task appends the filename to build the link sent to the user.
EXPORT_URL_PREFIX = os.getenv("EXPORT_URL_PREFIX", "http://localhost:8000/api/v1/export/download")
# How many days an export file is kept before cleanup_expired_exports deletes it.
EXPORT_RETENTION_DAYS = int(os.getenv("EXPORT_RETENTION_DAYS", "7"))

# Column order of the exported CSV (each name is also the StandardTradeRecord attribute read).
_EXPORT_FIELDNAMES = (
    'record_id', 'source_country', 'trade_direction', 'trade_date',
    'importer_name', 'exporter_name', 'hs_code', 'product_name',
    'amount', 'currency', 'weight', 'weight_unit',
    'origin_country', 'destination_country', 'departure_port', 'arrival_port',
    'transport_mode',
)

# Suffix used while an export is being written; renamed to the final name on success.
_PARTIAL_SUFFIX = ".part"

# Any character outside this set (including "@", "." and path separators) is
# replaced by "_" when the user's email is embedded in a filename.
_UNSAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9_+-]")

# Holds one event loop per thread so repeated _run_async calls reuse the same
# loop, as the original get_event_loop()-based code did. Presumably this keeps
# loop-bound async resources (e.g. pooled DB connections) usable across tasks.
_loop_state = threading.local()


def _get_thread_loop() -> asyncio.AbstractEventLoop:
    """Return this thread's reusable event loop, creating (and installing) it if needed."""
    loop = getattr(_loop_state, "loop", None)
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        _loop_state.loop = loop
    return loop


def _run_async(coro):
    """Run a coroutine to completion from synchronous (Celery task) code.

    Args:
        coro: The coroutine to execute.

    Returns:
        Whatever the coroutine returns; exceptions it raises propagate.

    When no event loop is running in the current thread, the coroutine runs on
    a persistent per-thread loop. If a loop *is* already running here, blocking
    on any loop in this thread would raise RuntimeError, so the coroutine is
    executed on a fresh loop in a short-lived helper thread instead.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        pass  # No running loop: the normal Celery worker case.
    else:
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()
    return _get_thread_loop().run_until_complete(coro)


def _record_to_row(record) -> dict:
    """Map a StandardTradeRecord onto a CSV row dict.

    Missing (None) values become '' and ``trade_date`` is rendered with
    ``isoformat()``; every other value is written as-is.
    """
    row = {}
    for field in _EXPORT_FIELDNAMES:
        value = getattr(record, field)
        if value is None:
            row[field] = ''
        elif field == 'trade_date':
            row[field] = value.isoformat()
        else:
            row[field] = value
    return row


async def _generate_export_file(query_params: dict, filename: str) -> tuple[str, int]:
    """Query trade records and stream them into a CSV file under EXPORT_DIR.

    Args:
        query_params: Optional filters. "country" matches source_country exactly,
            "hs_code" is a prefix match, and "start_date"/"end_date" are
            inclusive bounds on trade_date.
        filename: Name of the CSV file to create inside EXPORT_DIR.

    Returns:
        (file_path, record_count)

    The header row is always written, so an export with no matching rows is
    still a valid CSV. Data is first written to a ``.part`` file and renamed
    only on success, so a failed export never leaves a truncated file under
    the final name.
    """
    from packages.core.database import AsyncSessionLocal
    from packages.core.models import StandardTradeRecord
    from sqlalchemy import select

    # Make sure the export directory exists.
    Path(EXPORT_DIR).mkdir(parents=True, exist_ok=True)
    file_path = os.path.join(EXPORT_DIR, filename)
    tmp_path = file_path + _PARTIAL_SUFFIX

    # Build the query from whichever filters were supplied.
    stmt = select(StandardTradeRecord)
    if query_params.get("country"):
        stmt = stmt.where(StandardTradeRecord.source_country == query_params["country"])
    if query_params.get("hs_code"):
        stmt = stmt.where(StandardTradeRecord.hs_code.like(f"{query_params['hs_code']}%"))
    if query_params.get("start_date"):
        stmt = stmt.where(StandardTradeRecord.trade_date >= query_params["start_date"])
    if query_params.get("end_date"):
        stmt = stmt.where(StandardTradeRecord.trade_date <= query_params["end_date"])

    record_count = 0
    try:
        async with AsyncSessionLocal() as session:
            # Stream results instead of loading the whole result set into memory.
            result = await session.stream(stmt)
            with open(tmp_path, 'w', newline='', encoding='utf-8') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=_EXPORT_FIELDNAMES)
                writer.writeheader()
                async for row in result:
                    writer.writerow(_record_to_row(row[0]))
                    record_count += 1
        os.replace(tmp_path, file_path)
    except BaseException:
        # Also covers task cancellation; never leave a partial file behind.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise

    return file_path, record_count


async def _send_export_notification(user_email: str, file_url: str, record_count: int):
    """Tell the user their export is ready (best effort: errors are logged, never raised).

    If the SENDGRID_API_KEY environment variable is set, an HTML email is sent
    through SendGrid's v3 mail/send endpoint. Otherwise, the download link is only logged.
    """
    try:
        sendgrid_api_key = os.getenv("SENDGRID_API_KEY")
        if not sendgrid_api_key:
            app_logger.info(f"Export ready for {user_email}: {file_url} ({record_count} records)")
            return

        import httpx

        # The body is text/html: use explicit paragraphs (plain newlines would
        # collapse into one line) and escape the interpolated URL.
        safe_url = html.escape(file_url, quote=True)
        html_body = (
            f"<p>导出记录数:{record_count}</p>"
            f'<p>下载链接:<a href="{safe_url}">{safe_url}</a></p>'
            f"<p>提示:下载链接将在 {EXPORT_RETENTION_DAYS} 天后过期。</p>"
        )
        payload = {
            "personalizations": [{
                "to": [{"email": user_email}],
                "subject": "海关数据导出完成"
            }],
            "from": {"email": "noreply@customs-data.com", "name": "海关数据系统"},
            "content": [{
                "type": "text/html",
                "value": html_body
            }]
        }

        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.sendgrid.com/v3/mail/send",
                headers={
                    "Authorization": f"Bearer {sendgrid_api_key}",
                    "Content-Type": "application/json"
                },
                json=payload,
                timeout=10.0
            )

        # SendGrid answers 202 Accepted when the message is queued.
        if response.status_code == 202:
            app_logger.info(f"Export notification email sent to {user_email}")
        else:
            app_logger.error(f"Failed to send email: {response.text}")
    except Exception as e:
        app_logger.error(f"Error sending export notification: {e}")


def _build_export_filename(user_email: str) -> str:
    """Return a unique, filesystem-safe CSV filename for an export requested by *user_email*.

    The random token prevents two exports started within the same second from
    overwriting each other, and keeps the download URL from being guessable
    from the email address and request time alone.
    """
    timestamp = int(time.time())
    safe_email = _UNSAFE_FILENAME_CHARS.sub("_", user_email)
    token = secrets.token_hex(8)
    return f"export_{timestamp}_{safe_email}_{token}.csv"


@shared_task(name="export_data_task")
def export_data_task(query_params: dict, user_email: str):
    """Asynchronous export task for large result sets.

    1. Query the data and write it to a CSV file.
    2. Build the download link.
    3. Notify the user by email.

    Returns:
        On success, a dict with status "success", file_url, file_path,
        record_count, user_email and expires_at (ISO timestamp, local time).
        On failure, a dict with status "failed", error and user_email. The
        exception is not re-raised, so Celery records the task as succeeded.
    """
    app_logger.info(f"Starting async export task for {user_email} with params: {query_params}")

    try:
        filename = _build_export_filename(user_email)

        file_path, record_count = _run_async(_generate_export_file(query_params, filename))

        file_url = f"{EXPORT_URL_PREFIX}/{filename}"

        _run_async(_send_export_notification(user_email, file_url, record_count))

        app_logger.info(f"Export task completed. File: {file_path}, Records: {record_count}")

        return {
            "status": "success",
            "file_url": file_url,
            "file_path": file_path,
            "record_count": record_count,
            "user_email": user_email,
            "expires_at": (datetime.now() + timedelta(days=EXPORT_RETENTION_DAYS)).isoformat()
        }
    except Exception as e:
        app_logger.error(f"Export task failed: {e}")
        return {
            "status": "failed",
            "error": str(e),
            "user_email": user_email
        }


def _remove_if_expired(file_path: Path, cutoff: float) -> bool:
    """Delete *file_path* if its mtime is older than *cutoff* (epoch seconds).

    Returns True only if this call deleted the file. A file that vanished
    concurrently is skipped silently, and other OS errors are logged. Neither
    case aborts the surrounding sweep.
    """
    try:
        if file_path.stat().st_mtime >= cutoff:
            return False
        file_path.unlink()
    except FileNotFoundError:
        return False
    except OSError as e:
        app_logger.error(f"Error during export cleanup: {e}")
        return False
    return True


@shared_task(name="cleanup_expired_exports")
def cleanup_expired_exports():
    """Periodically delete export files older than EXPORT_RETENTION_DAYS.

    This covers finished ``*.csv`` exports and stale ``*.csv.part`` leftovers
    from exports that were interrupted.

    Returns:
        {"deleted_count": n} on success (n is 0 when EXPORT_DIR does not
        exist), or {"error": message} if the sweep itself fails.
    """
    app_logger.info("Starting cleanup of expired export files")

    try:
        export_dir = Path(EXPORT_DIR)
        if not export_dir.exists():
            return {"deleted_count": 0}

        # Compare in epoch seconds: naive local datetimes drift by an hour across DST changes.
        cutoff = time.time() - EXPORT_RETENTION_DAYS * 24 * 3600
        deleted_count = 0

        for pattern in ("*.csv", f"*.csv{_PARTIAL_SUFFIX}"):
            for file_path in export_dir.glob(pattern):
                if _remove_if_expired(file_path, cutoff):
                    deleted_count += 1
                    app_logger.info(f"Deleted expired export file: {file_path.name}")

        app_logger.info(f"Cleanup completed. Deleted {deleted_count} expired files")
        return {"deleted_count": deleted_count}

    except Exception as e:
        app_logger.error(f"Error during export cleanup: {e}")
        return {"error": str(e)}