Spaces:
Running
Running
| from celery import shared_task | |
| from packages.core.logger import app_logger | |
| import time | |
| import os | |
| import csv | |
| import asyncio | |
| from datetime import datetime, timedelta | |
| from pathlib import Path | |
| from typing import Optional | |
# Directory where export files are stored (could be swapped for object
# storage such as S3/OSS).
EXPORT_DIR = os.environ.get("EXPORT_DIR", "/tmp/customs_exports")
EXPORT_URL_PREFIX = os.environ.get(
    "EXPORT_URL_PREFIX",
    "http://localhost:8000/api/v1/export/download",
)
# How long export files are retained, in days.
EXPORT_RETENTION_DAYS = int(os.environ.get("EXPORT_RETENTION_DAYS", "7"))
def _run_async(coro):
    """Run the coroutine ``coro`` to completion from synchronous code.

    Args:
        coro: The coroutine object to execute.

    Returns:
        Whatever ``coro`` returns.

    Raises:
        Any exception raised by ``coro`` is propagated to the caller.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread; drive one directly below.
        pass
    else:
        # A loop is already running in this thread. A second loop cannot be
        # driven from the same thread ("Cannot run the event loop while
        # another loop is running"), so execute the coroutine on a fresh loop
        # in a short-lived worker thread and block until it finishes.
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    # Reuse this thread's loop across calls when one is available, so objects
    # bound to that loop stay usable between invocations.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(coro)
async def _generate_export_file(query_params: dict, filename: str) -> tuple[str, int]:
    """Query trade records and stream them into a CSV export file.

    The file is created as ``EXPORT_DIR/filename`` (the directory is created
    if missing). The header row is always written, so an export that matches
    no records still yields a valid, header-only CSV.

    Args:
        query_params: Optional filters; missing or falsy entries are ignored.
            ``country`` matches ``source_country`` exactly, ``hs_code`` is a
            prefix match via ``LIKE '<value>%'``, and ``start_date`` /
            ``end_date`` bound ``trade_date`` inclusively.
        filename: Name of the file to create inside ``EXPORT_DIR``.

    Returns:
        ``(file_path, record_count)``: the path of the written file and the
        number of data rows written (header excluded).
    """
    from packages.core.database import AsyncSessionLocal
    from packages.core.models import StandardTradeRecord
    from sqlalchemy import select

    fieldnames = [
        'record_id', 'source_country', 'trade_direction', 'trade_date',
        'importer_name', 'exporter_name', 'hs_code', 'product_name',
        'amount', 'currency', 'weight', 'weight_unit',
        'origin_country', 'destination_country', 'departure_port',
        'arrival_port', 'transport_mode'
    ]

    # Make sure the export directory exists.
    Path(EXPORT_DIR).mkdir(parents=True, exist_ok=True)
    file_path = os.path.join(EXPORT_DIR, filename)

    async with AsyncSessionLocal() as session:
        # Build the query from whichever filters were supplied.
        stmt = select(StandardTradeRecord)
        if query_params.get("country"):
            stmt = stmt.where(StandardTradeRecord.source_country == query_params["country"])
        if query_params.get("hs_code"):
            stmt = stmt.where(StandardTradeRecord.hs_code.like(f"{query_params['hs_code']}%"))
        if query_params.get("start_date"):
            stmt = stmt.where(StandardTradeRecord.trade_date >= query_params["start_date"])
        if query_params.get("end_date"):
            stmt = stmt.where(StandardTradeRecord.trade_date <= query_params["end_date"])

        # Stream rows straight into the CSV rather than loading them all.
        result = await session.stream(stmt)
        record_count = 0
        with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            # Written unconditionally so an empty result is still a valid CSV.
            writer.writeheader()
            async for row in result:
                record = row[0]
                writer.writerow({
                    'record_id': record.record_id,
                    'source_country': record.source_country,
                    'trade_direction': record.trade_direction,
                    'trade_date': record.trade_date.isoformat() if record.trade_date else '',
                    'importer_name': record.importer_name or '',
                    'exporter_name': record.exporter_name or '',
                    'hs_code': record.hs_code or '',
                    'product_name': record.product_name or '',
                    # Explicit None checks keep a legitimate 0 in the output.
                    'amount': record.amount if record.amount is not None else '',
                    'currency': record.currency or '',
                    'weight': record.weight if record.weight is not None else '',
                    'weight_unit': record.weight_unit or '',
                    'origin_country': record.origin_country or '',
                    'destination_country': record.destination_country or '',
                    'departure_port': record.departure_port or '',
                    'arrival_port': record.arrival_port or '',
                    'transport_mode': record.transport_mode or ''
                })
                record_count += 1
    return file_path, record_count
async def _send_export_notification(user_email: str, file_url: str, record_count: int):
    """Notify the user by email that their export is ready.

    If the ``SENDGRID_API_KEY`` environment variable is set, an HTML email is
    posted to the SendGrid v3 mail API; otherwise the download link is only
    written to the application log. This is best-effort: any exception is
    logged and swallowed so a notification failure never fails the export.

    Args:
        user_email: Recipient address.
        file_url: Download link to include in the message.
        record_count: Number of exported records, shown in the message.
    """
    try:
        sendgrid_api_key = os.getenv("SENDGRID_API_KEY")
        if not sendgrid_api_key:
            # No mail service configured: just record where the file is.
            app_logger.info(f"Export ready for {user_email}: {file_url} ({record_count} records)")
            return

        # Imported only on the send path, so a missing HTTP client cannot
        # affect the log-only path above.
        import html
        import httpx

        # The URL goes into both an attribute value and element text; escape
        # it so characters such as quotes or '<' cannot break the markup.
        safe_url = html.escape(file_url, quote=True)
        payload = {
            "personalizations": [{
                "to": [{"email": user_email}],
                "subject": "海关数据导出完成"
            }],
            "from": {"email": "noreply@customs-data.com", "name": "海关数据系统"},
            "content": [{
                "type": "text/html",
                "value": f"""
                <h2>您的数据导出已完成</h2>
                <p>导出记录数:{record_count}</p>
                <p>下载链接:<a href="{safe_url}">{safe_url}</a></p>
                <p>提示:下载链接将在 {EXPORT_RETENTION_DAYS} 天后过期。</p>
                """
            }]
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.sendgrid.com/v3/mail/send",
                headers={
                    "Authorization": f"Bearer {sendgrid_api_key}",
                    "Content-Type": "application/json"
                },
                json=payload,
                timeout=10.0
            )
        # The code treats HTTP 202 as the success status for this endpoint.
        if response.status_code == 202:
            app_logger.info(f"Export notification email sent to {user_email}")
        else:
            app_logger.error(f"Failed to send email: {response.text}")
    except Exception as e:
        app_logger.error(f"Error sending export notification: {e}")
def export_data_task(query_params: dict, user_email: str):
    """Export a large result set to CSV and notify the requesting user.

    Steps:
        1. Query the data and generate the CSV file.
        2. Build the download link.
        3. Send the notification to the user.

    Args:
        query_params: Filters passed through to ``_generate_export_file``.
        user_email: Address of the requesting user; used for the
            notification and, in sanitised form, in the file name.

    Returns:
        On success, a dict with ``status`` ``"success"``, ``file_url``,
        ``file_path``, ``record_count``, ``user_email`` and ``expires_at``
        (ISO-8601, now + ``EXPORT_RETENTION_DAYS``). On any exception, a dict
        with ``status`` ``"failed"``, ``error`` and ``user_email``; the
        exception itself is logged, not re-raised.
    """
    # NOTE(review): `shared_task` is imported at the top of the file but no
    # decorator is visible on this function here — confirm how it is
    # registered with Celery.
    import re

    app_logger.info(f"Starting async export task for {user_email} with params: {query_params}")
    try:
        # Build the file name. Only characters that are safe in both a file
        # name and a URL path segment are kept from the address; everything
        # else ('@', '.', '/', quotes, whitespace, ...) becomes '_', so the
        # address cannot inject path separators or break the download link.
        # NOTE(review): the timestamp has one-second resolution, so two
        # exports for the same address within the same second share a name.
        timestamp = int(time.time())
        email_slug = re.sub(r"[^A-Za-z0-9_-]", "_", user_email)
        filename = f"export_{timestamp}_{email_slug}.csv"

        # Run the export.
        file_path, record_count = _run_async(_generate_export_file(query_params, filename))

        # Build the download link.
        file_url = f"{EXPORT_URL_PREFIX}/{filename}"

        # Send the notification.
        _run_async(_send_export_notification(user_email, file_url, record_count))

        app_logger.info(f"Export task completed. File: {file_path}, Records: {record_count}")
        return {
            "status": "success",
            "file_url": file_url,
            "file_path": file_path,
            "record_count": record_count,
            "user_email": user_email,
            "expires_at": (datetime.now() + timedelta(days=EXPORT_RETENTION_DAYS)).isoformat()
        }
    except Exception as e:
        app_logger.error(f"Export task failed: {e}")
        return {
            "status": "failed",
            "error": str(e),
            "user_email": user_email
        }
def cleanup_expired_exports():
    """Delete export CSV files older than the retention period.

    Scans ``EXPORT_DIR`` for ``*.csv`` files and removes those whose
    modification time is more than ``EXPORT_RETENTION_DAYS`` days in the
    past. A file that cannot be inspected or removed is logged and skipped,
    so it does not stop the rest of the sweep.

    Returns:
        ``{"deleted_count": n}`` on completion (``n`` is 0 when the export
        directory does not exist), or ``{"error": message}`` if the sweep
        itself fails unexpectedly.
    """
    app_logger.info("Starting cleanup of expired export files")
    try:
        export_dir = Path(EXPORT_DIR)
        deleted_count = 0
        if not export_dir.is_dir():
            # Nothing has been exported yet; report an empty sweep in the
            # same shape as a normal run.
            return {"deleted_count": deleted_count}

        # Compare epoch seconds directly: files modified before the cutoff
        # are expired. This avoids naive local-datetime arithmetic.
        cutoff = time.time() - EXPORT_RETENTION_DAYS * 24 * 3600

        for file_path in export_dir.glob("*.csv"):
            try:
                # Skip directories that happen to match the pattern, and
                # files that are still within the retention window.
                if not file_path.is_file() or file_path.stat().st_mtime >= cutoff:
                    continue
                file_path.unlink()
            except FileNotFoundError:
                # Removed by someone else between listing and deletion.
                continue
            except OSError as e:
                app_logger.error(f"Could not delete export file {file_path.name}: {e}")
                continue
            deleted_count += 1
            app_logger.info(f"Deleted expired export file: {file_path.name}")

        app_logger.info(f"Cleanup completed. Deleted {deleted_count} expired files")
        return {"deleted_count": deleted_count}
    except Exception as e:
        app_logger.error(f"Error during export cleanup: {e}")
        return {"error": str(e)}