Spaces:
Running
Running
| from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| import os | |
| from pathlib import Path | |
| from apps.api.dependencies.auth import get_api_key, get_current_user_tier | |
| from apps.worker.export_tasks import export_data_task, EXPORT_DIR | |
# Every endpoint in this module is mounted under the "/export" URL prefix and
# grouped under the "Export" tag in the generated OpenAPI documentation.
router = APIRouter(
    tags=["Export"],
    prefix="/export",
)
class ExportRequest(BaseModel):
    """Request body for submitting an asynchronous export task."""

    # Optional filter forwarded to the export task as query_params["hs_code"];
    # presumably a Harmonized System commodity code — confirm with the worker.
    hs_code: Optional[str] = None
    # Optional filter forwarded to the export task as query_params["country"].
    country: Optional[str] = None
    # Date range bounds forwarded verbatim as strings; no format validation is
    # done here. NOTE(review): the expected date format is defined by the worker.
    start_date: str
    end_date: str
    # Passed to the export task as `user_email`. It is a plain `str`, so any
    # value is accepted. NOTE(review): confirm whether this should be validated
    # or restricted to the requesting user's own address.
    email: str
# NOTE(review): no @router.<method>(...) decorator is visible for this handler
# in this copy of the file; confirm the route is actually registered.
async def request_async_export(
    req: ExportRequest,
    api_key: str = Depends(get_api_key),
    tier: str = Depends(get_current_user_tier)
):
    """
    Submit an asynchronous export task.

    Trial-tier users are rejected; every other tier may submit an export. No
    further tier-specific limits on data volume or range are applied here.

    Args:
        req: Export filters and the e-mail address handed to the export task.
        api_key: API key resolved by ``get_api_key``; used only for
            authentication and not otherwise read here.
        tier: The caller's plan tier resolved by ``get_current_user_tier``.

    Returns:
        A dict with a confirmation message, the Celery task id and the tier.

    Raises:
        HTTPException: 403 when ``tier`` is ``"trial"``.
    """
    if tier == "trial":
        raise HTTPException(status_code=403, detail="Trial users cannot export data. Please upgrade your plan.")

    from fastapi.concurrency import run_in_threadpool

    query_params = {
        "hs_code": req.hs_code,
        "country": req.country,
        "start_date": req.start_date,
        "end_date": req.end_date
    }
    # Enqueue the task on Celery. `.delay()` is a synchronous call that
    # publishes to the broker, so it runs in the threadpool. Calling it directly
    # here would block the event loop, and every other request with it, while
    # the broker is slow or unreachable.
    task = await run_in_threadpool(
        export_data_task.delay,
        query_params=query_params,
        user_email=req.email,
    )
    return {
        "message": "Export task submitted successfully.",
        "task_id": task.id,
        "tier": tier
    }
async def get_export_status(task_id: str, api_key: str = Depends(get_api_key)):
    """
    Query the status of an asynchronous export task.

    Args:
        task_id: Celery task id returned when the export was submitted.
        api_key: API key resolved by ``get_api_key``; used only for
            authentication and not otherwise read here.

    Returns:
        ``{"task_id", "status"}``, where ``status`` is ``"Pending"``,
        ``"Completed"`` (plus ``"result"``), ``"Failed"`` (plus ``"error"``), or
        the raw Celery state name for any other state (e.g. ``"STARTED"``).
        Celery also reports ids it does not know as ``PENDING``, so an unknown
        ``task_id`` yields ``"Pending"`` as well.
    """
    from packages.core.celery_app import celery_app

    task = celery_app.AsyncResult(task_id)
    # Read the state exactly once so every branch below sees the same value.
    # `AsyncResult.state` may consult the result backend on each access while
    # the task is unfinished, so repeated reads cost extra round-trips. They
    # can also observe the task finishing midway through the comparisons.
    state = task.state
    # NOTE(review): no check ties task_id to the caller's API key; any
    # authenticated caller can read any task's result. Confirm this is intended.
    if state == 'PENDING':
        return {"task_id": task_id, "status": "Pending"}
    if state == 'SUCCESS':
        return {"task_id": task_id, "status": "Completed", "result": task.result}
    if state == 'FAILURE':
        # For failed tasks `info` holds the raised exception.
        return {"task_id": task_id, "status": "Failed", "error": str(task.info)}
    return {"task_id": task_id, "status": state}
async def download_export_file(
    filename: str,
    api_key: str = Depends(get_api_key)
):
    """
    Download a previously generated export file.

    Security checks:
    1. The API key is validated by the ``get_api_key`` dependency.
    2. Path traversal is prevented: the name must be a bare file name, and the
       resolved path must stay inside ``EXPORT_DIR``.
    3. Only existing regular ``.csv`` files are served.

    Args:
        filename: Bare name of the export file inside ``EXPORT_DIR``.
        api_key: API key resolved by ``get_api_key``; used only for
            authentication and not otherwise read here.

    Returns:
        A ``FileResponse`` sending the file as a ``text/csv`` attachment.

    Raises:
        HTTPException: 400 for an invalid or non-CSV file name; 404 when no such
            file exists (e.g. it was never created or has been removed).
    """
    # Reject anything that is not a bare file name. NUL bytes are rejected too:
    # they are invalid in filesystem paths and would otherwise surface as a 500
    # from the filesystem calls below.
    if ".." in filename or any(ch in filename for ch in ("/", "\\", "\x00")):
        raise HTTPException(status_code=400, detail="Invalid filename")

    # Only CSV exports are served.
    if not filename.endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")

    # Defence in depth: resolve the final path (following symlinks and
    # normalising e.g. drive-qualified names on Windows). Then require that it
    # is still inside the export directory.
    export_root = Path(EXPORT_DIR).resolve()
    file_path = (export_root / filename).resolve()
    try:
        file_path.relative_to(export_root)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid filename") from None

    # is_file() is False both for missing paths and for non-regular files.
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found or has expired")

    # Let FileResponse build Content-Disposition from `filename`. It defaults to
    # "attachment", quotes the name, and RFC 5987-encodes non-ASCII names. A
    # hand-written header would override that and send the name unquoted.
    # NOTE(review): any valid API key can download any export by name; confirm
    # whether files should be scoped to the key that requested them.
    return FileResponse(
        path=str(file_path),
        media_type="text/csv",
        filename=filename,
    )