from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Any, Optional, Tuple
import os
from pathlib import Path

from apps.api.dependencies.auth import get_api_key, get_current_user_tier
from apps.worker.export_tasks import export_data_task, EXPORT_DIR

router = APIRouter(prefix="/export", tags=["Export"])


class ExportRequest(BaseModel):
    """Request body for submitting an asynchronous export job."""

    # Optional filters, forwarded unchanged to the worker in `query_params`.
    hs_code: Optional[str] = None
    country: Optional[str] = None
    # Date bounds, forwarded to the worker as raw strings (no format validation here).
    start_date: str
    end_date: str
    # Recipient address, forwarded to the worker as `user_email`.
    # NOTE(review): this is caller-supplied and unvalidated, so any authenticated
    # caller can direct an export to an arbitrary address — confirm this is intended.
    email: str


@router.post("/")
async def request_async_export(
    req: ExportRequest,
    api_key: str = Depends(get_api_key),
    tier: str = Depends(get_current_user_tier),
):
    """
    Submit an asynchronous export task.

    The caller's tier comes from ``get_current_user_tier``. Trial users are
    rejected; every other tier is accepted and echoed back in the response.

    Returns:
        A dict with a confirmation message, the Celery ``task_id`` and the ``tier``.

    Raises:
        HTTPException(403): if the caller is on the ``"trial"`` tier.
    """
    if tier == "trial":
        raise HTTPException(status_code=403, detail="Trial users cannot export data. Please upgrade your plan.")

    # NOTE(review): `tier` is not passed to the worker, so any per-tier limits on
    # export size/scope must be enforced elsewhere — confirm.
    query_params = {
        "hs_code": req.hs_code,
        "country": req.country,
        "start_date": req.start_date,
        "end_date": req.end_date,
    }

    # `.delay()` is a synchronous (non-awaited) call that publishes to the Celery
    # broker; run it in the threadpool so a slow broker cannot stall the event loop.
    task = await run_in_threadpool(
        export_data_task.delay,
        query_params=query_params,
        user_email=req.email,
    )

    return {
        "message": "Export task submitted successfully.",
        "task_id": task.id,
        "tier": tier,
    }


def _fetch_task_status(task_id: str) -> Tuple[str, Any]:
    """Look up a Celery task and return ``(state, payload)``.

    This performs blocking result-backend I/O; call it through
    ``run_in_threadpool`` from async code. ``task.state`` is read exactly once,
    so every branch in the caller sees the same value. ``payload`` is the task
    result for SUCCESS, the stored exception/info for FAILURE, otherwise None.
    """
    # Imported lazily, as in the original endpoint (presumably to avoid an
    # import cycle at module load — confirm).
    from packages.core.celery_app import celery_app

    task = celery_app.AsyncResult(task_id)
    state = task.state
    if state == "SUCCESS":
        return state, task.result
    if state == "FAILURE":
        return state, task.info
    return state, None


@router.get("/status/{task_id}")
async def get_export_status(task_id: str, api_key: str = Depends(get_api_key)):
    """
    Query the status of an asynchronous export task.

    Returns:
        A dict with ``task_id`` and a ``status`` of "Pending", "Completed"
        (plus ``result``), "Failed" (plus ``error``), or the raw Celery state.
    """
    # NOTE(review): any valid API key can query any task id; there is no check
    # that the task belongs to the caller.
    state, payload = await run_in_threadpool(_fetch_task_status, task_id)

    if state == "PENDING":
        # NOTE(review): Celery generally also reports PENDING for task ids it has
        # never seen, so this cannot distinguish "queued" from "unknown id".
        return {"task_id": task_id, "status": "Pending"}
    if state == "SUCCESS":
        return {"task_id": task_id, "status": "Completed", "result": payload}
    if state == "FAILURE":
        return {"task_id": task_id, "status": "Failed", "error": str(payload)}
    return {"task_id": task_id, "status": state}


@router.get("/download/{filename}")
async def download_export_file(
    filename: str,
    api_key: str = Depends(get_api_key),
):
    """
    Download a generated export file.

    Security checks:
    1. The API key is validated by the ``get_api_key`` dependency.
    2. Path traversal is prevented in two ways. Path separators, ``..`` and NUL
       bytes are rejected outright. In addition, the resolved path must live
       directly inside ``EXPORT_DIR``.
    3. Only existing regular files ending in ``.csv`` are served.

    Raises:
        HTTPException(400): invalid filename or non-CSV extension.
        HTTPException(404): the file does not exist (e.g. it has expired).
    """
    # NOTE(review): any valid API key can download any export file; there is no
    # check that the file belongs to the caller.

    # Cheap lexical rejection first.
    if ".." in filename or "/" in filename or "\\" in filename or "\x00" in filename:
        raise HTTPException(status_code=400, detail="Invalid filename")

    # Only CSV files are allowed.
    if not filename.endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")

    export_root = Path(EXPORT_DIR).resolve()
    file_path = (export_root / filename).resolve()
    # Defence in depth: after resolving symlinks, the target must still sit
    # directly inside the export directory.
    if file_path.parent != export_root:
        raise HTTPException(status_code=400, detail="Invalid filename")

    # is_file() is False for missing paths as well as for non-regular files.
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found or has expired")

    # FileResponse builds a properly quoted (or RFC 5987-encoded, for non-ASCII
    # names) Content-Disposition header from `filename`. The previous hand-built
    # header overrode it with an unquoted value.
    return FileResponse(
        path=str(file_path),
        media_type="text/csv",
        filename=filename,
    )