"""VectorMatch FastAPI backend: dataset upload, match tasks, results, settings and health."""
import asyncio
import datetime
import json
import logging
import os
import time
import traceback
from typing import List, Optional

import httpx
from fastapi import (
    BackgroundTasks,
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Query,
    Request,
    UploadFile,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy.orm import Session
from starlette.requests import Request  # same class that fastapi re-exports
from starlette.responses import JSONResponse

from database import get_db, init_db, SessionLocal
from models import (
    VectorMatchTask,
    VectorDataset,
    VectorDataRow,
    VectorEmbedding,
    MatchResult,
)
from schemas import (
    TaskCreate,
    TaskDetail,
    TaskProgress,
    TaskListItem,
    MatchResultItem,
    MatchResultPage,
    SourceWithCandidates,
    CandidateDetail,
    UploadResponse,
    SettingItem,
    SettingsResponse,
    DatasetInfo,
)
from services.excel_service import save_upload_file, get_sheet_info, parse_excel_rows
from services.match_service import run_match_task

app = FastAPI(title="VectorMatch API", version="1.0.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

logger = logging.getLogger("uvicorn.error")


@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log any unhandled exception with its traceback and answer with a JSON 500.

    NOTE(review): ``str(exc)`` is echoed to the client, which can leak internal
    details (SQL, file paths); consider a generic message.
    """
    # exc_info=exc attaches the traceback of *exc* itself, regardless of
    # whether we are still inside the original ``except`` block.
    logger.error("Unhandled error on %s %s", request.method, request.url, exc_info=exc)
    return JSONResponse(status_code=500, content={"detail": str(exc)})


# ─── Health status cache ──────────────────────────────────────────────────
_health_cache = {
    "result": {
        "embedding_ok": False,
        "reranker_ok": False,
        "embedding_model": "",
        "reranker_model": "",
        "reranker_enabled": False,
        "has_api_key": False,
    },
    "updated_at": 0,
}
_HEALTH_TTL = 30  # cache lifetime (seconds)

# Strong reference to the background polling task. The event loop only keeps
# weak references to tasks, so an unreferenced task may be garbage-collected.
_health_task: Optional[asyncio.Task] = None


async def _probe(client: httpx.AsyncClient, url: str, api_key: str, payload: dict) -> bool:
    """POST *payload* to *url* with bearer auth; True iff the API answers HTTP 200.

    Best-effort: any request failure is logged at debug level and reported as False.
    """
    try:
        resp = await client.post(
            url,
            headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
            json=payload,
        )
    except Exception as err:
        logger.debug("Health probe %s failed: %s", url, err)
        return False
    return resp.status_code == 200


async def _do_health_check():
    """Probe the embedding (and, if enabled, reranker) APIs and refresh the cache.

    Returns the new status dict, which is also stored in ``_health_cache``.
    API/network failures do not raise; they leave the matching ``*_ok`` flag False.
    """
    # Imported at call time and read via attributes so values patched by
    # update_settings are picked up without a restart.
    import services.embedding_service as es

    api_key = es.SILICONFLOW_API_KEY
    result = {
        "embedding_ok": False,
        "reranker_ok": False,
        "embedding_model": es.EMBEDDING_MODEL,
        "reranker_model": es.RERANKER_MODEL,
        "reranker_enabled": es.RERANKER_ENABLED,
        "has_api_key": bool(api_key),
    }
    if api_key:
        try:
            # trust_env=False ignores HTTP(S)_PROXY env vars, which is what the
            # former ``proxies={}`` meant; that kwarg was removed in httpx 0.28
            # and raised TypeError here, making every probe silently fail.
            async with httpx.AsyncClient(timeout=5.0, trust_env=False) as client:
                result["embedding_ok"] = await _probe(
                    client,
                    "https://api.siliconflow.cn/v1/embeddings",
                    api_key,
                    {"model": es.EMBEDDING_MODEL, "input": ["ping"]},
                )
                if es.RERANKER_ENABLED:
                    result["reranker_ok"] = await _probe(
                        client,
                        "https://api.siliconflow.cn/v1/rerank",
                        api_key,
                        {"model": es.RERANKER_MODEL, "query": "ping", "documents": ["pong"], "top_n": 1},
                    )
        except Exception:
            logger.warning("Health check could not run", exc_info=True)
    _health_cache["result"] = result
    _health_cache["updated_at"] = time.time()
    return result


async def _health_polling_loop():
    """Refresh the health cache every ``_HEALTH_TTL`` seconds until cancelled."""
    while True:
        try:
            await _do_health_check()
        except Exception:
            logger.warning("Background health check failed", exc_info=True)
        await asyncio.sleep(_HEALTH_TTL)


@app.on_event("startup")
async def startup():
    """Initialise the database and start the background health-polling task once."""
    global _health_task
    init_db()
    if _health_task is None or _health_task.done():
        _health_task = asyncio.create_task(_health_polling_loop())


# ─── Upload Excel ───────────────────────────────────────────────────────────
@app.post("/api/upload", response_model=UploadResponse)
async def upload_excel(
    file: UploadFile = File(...),
    dataset_role: str = Form("source"),
    db: Session = Depends(get_db),
):
    """Save an uploaded Excel file, register it as a dataset, and describe its sheets.

    Raises 400 when the upload carries no usable file name.
    """
    # The file name is client-controlled: drop any directory part (POSIX or
    # Windows style) so it cannot escape the upload directory.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    if filename in ("", ".", ".."):
        raise HTTPException(400, "Missing or invalid file name")

    content = await file.read()
    filepath = save_upload_file(content, filename)
    info = get_sheet_info(filepath)

    dataset = VectorDataset(
        name=filename,
        file_name=filename,
        dataset_role=dataset_role,
        data_scope="task",
    )
    db.add(dataset)
    db.commit()
    db.refresh(dataset)

    return UploadResponse(
        dataset_id=dataset.id,
        file_name=filename,
        sheet_names=info["sheet_names"],
        columns=info["columns"],
        all_columns=info.get("all_columns", info["columns"]),
    )
# ─── Configure dataset (sheet, fields) ─────────────────────────────────────
@app.post("/api/dataset/{dataset_id}/configure")
def configure_dataset(
    dataset_id: int,
    sheet_name: str = Form(...),
    vector_fields: str = Form(...),
    db: Session = Depends(get_db),
):
    """Choose the sheet and vector fields of an uploaded dataset and import its rows.

    ``vector_fields`` is JSON, decoded and passed to ``parse_excel_rows``
    (presumably a list of column names - confirm with the frontend).
    Returns ``{"status": "ok", "row_count": N}``.
    Raises 404 if the dataset is unknown, 400 if ``vector_fields`` is not valid JSON.
    """
    dataset = db.query(VectorDataset).get(dataset_id)
    if not dataset:
        raise HTTPException(404, "Dataset not found")

    # Validate and parse before writing anything, so a bad payload is a 400
    # instead of a 500 that leaves the dataset half-configured.
    try:
        fields = json.loads(vector_fields)
    except json.JSONDecodeError as err:
        raise HTTPException(400, "vector_fields must be valid JSON") from err

    # file_name originates from the client at upload time; basename() keeps
    # the lookup inside the upload directory.
    filepath = os.path.join(
        os.path.dirname(__file__), "data", "uploads", os.path.basename(dataset.file_name)
    )
    rows = parse_excel_rows(filepath, sheet_name, fields)

    dataset.sheet_name = sheet_name
    dataset.vector_fields = vector_fields
    # NOTE(review): configuring the same dataset again appends another copy of
    # its rows; existing VectorDataRow records are not removed here.
    db.add_all([
        VectorDataRow(
            dataset_id=dataset.id,
            dataset_role=dataset.dataset_role,
            data_scope=dataset.data_scope,
            row_number=row_data["row_number"],
            raw_text=row_data["raw_text"],
            text_hash=row_data["text_hash"],
            field_values=row_data["field_values"],
        )
        for row_data in rows
    ])
    dataset.row_count = len(rows)
    db.commit()
    return {"status": "ok", "row_count": len(rows)}


# ─── Get dataset info ──────────────────────────────────────────────────────
@app.get("/api/dataset/{dataset_id}", response_model=DatasetInfo)
def get_dataset(dataset_id: int, db: Session = Depends(get_db)):
    """Return one dataset; 404 if it does not exist."""
    dataset = db.query(VectorDataset).get(dataset_id)
    if not dataset:
        raise HTTPException(404, "Dataset not found")
    return dataset


# ─── Create & start task ───────────────────────────────────────────────────
@app.post("/api/task", response_model=TaskDetail)
async def create_task(
    background_tasks: BackgroundTasks,
    source_dataset_id: int = Form(...),
    target_dataset_id: int = Form(...),
    match_mode: str = Form("two_file"),
    top_k: int = Form(10),
    rerank_top_k: int = Form(3),
    min_threshold: float = Form(0.70),
    candidate_scope: str = Form("current_task_target"),
    db: Session = Depends(get_db),
):
    """Create a pending match task for two datasets and schedule it in the background.

    Both datasets and all their rows get ``task_id`` set to the new task.
    Raises 400 if either dataset does not exist.
    """
    src = db.query(VectorDataset).get(source_dataset_id)
    tgt = db.query(VectorDataset).get(target_dataset_id)
    if not src or not tgt:
        raise HTTPException(400, "Source or target dataset not found")

    # Millisecond-resolution timestamp in UTC+8, e.g. "20240101120000123".
    now = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=8)))
    task_code = now.strftime("%Y%m%d%H%M%S") + f"{now.microsecond // 1000:03d}"

    task = VectorMatchTask(
        task_code=task_code,
        match_mode=match_mode,
        candidate_scope=candidate_scope,
        source_dataset_id=source_dataset_id,
        target_dataset_id=target_dataset_id,
        top_k=top_k,
        rerank_top_k=rerank_top_k,
        min_threshold=min_threshold,
        status="pending",
    )
    db.add(task)
    db.commit()
    db.refresh(task)

    src.task_id = task.id
    tgt.task_id = task.id
    for dataset in (src, tgt):
        db.query(VectorDataRow).filter(VectorDataRow.dataset_id == dataset.id).update({"task_id": task.id})
    db.commit()

    background_tasks.add_task(_run_task_in_background, task.id)
    db.refresh(task)
    return task


def _run_task_in_background(task_id: int):
    """Run the async match pipeline for *task_id* to completion on a fresh event loop.

    Scheduled through BackgroundTasks; as a plain ``def`` it is expected to run
    in a worker thread with no running loop. ``asyncio.run`` always closes the
    loop, even when ``run_match_task`` raises (the old manual loop leaked then).
    """
    asyncio.run(run_match_task(task_id, SessionLocal))


def _get_alive_task(db: Session, task_id: int) -> Optional[VectorMatchTask]:
    """Return the task unless it is missing or soft-deleted (``is_delete`` NULL counts as 0)."""
    task = db.query(VectorMatchTask).get(task_id)
    if not task or (task.is_delete or 0) == 1:
        return None
    return task


_LIKE_ESCAPE = "/"


def _escape_like(text: str) -> str:
    """Escape LIKE wildcards in *text* (with ``_LIKE_ESCAPE``) so it matches literally."""
    return (
        text.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


def _load_rows(db: Session, row_ids, chunk_size: int = 500) -> dict:
    """Fetch the VectorDataRow objects for *row_ids* in a few queries, keyed by id.

    ``None`` ids are ignored. Ids are queried in chunks to keep each IN list
    below database bind-parameter limits.
    """
    ids = sorted({rid for rid in row_ids if rid is not None})
    rows = {}
    for start in range(0, len(ids), chunk_size):
        chunk = ids[start:start + chunk_size]
        for row in db.query(VectorDataRow).filter(VectorDataRow.id.in_(chunk)).all():
            rows[row.id] = row
    return rows


# ─── Task progress ─────────────────────────────────────────────────────────
@app.get("/api/task/{task_id}/progress", response_model=TaskProgress)
def get_task_progress(task_id: int, db: Session = Depends(get_db)):
    """Return progress of a non-deleted task; 404 otherwise."""
    task = _get_alive_task(db, task_id)
    if not task:
        raise HTTPException(404, "Task not found")
    return task


# ─── Task detail ───────────────────────────────────────────────────────────
@app.get("/api/task/{task_id}", response_model=TaskDetail)
def get_task_detail(task_id: int, db: Session = Depends(get_db)):
    """Return details of a non-deleted task; 404 otherwise."""
    task = _get_alive_task(db, task_id)
    if not task:
        raise HTTPException(404, "Task not found")
    return task


# ─── Task list ─────────────────────────────────────────────────────────────
@app.get("/api/tasks", response_model=List[TaskListItem])
def list_tasks(
    scope: str = Query("active", description="active=未归档, archived=仅归档, deleted=回收站"),
    db: Session = Depends(get_db),
):
    """List tasks newest first.

    ``scope``: ``active`` = not deleted and not archived, ``archived`` = not
    deleted and archived, ``deleted`` = soft-deleted (recycle bin).
    Raises 400 for any other scope.
    """
    from sqlalchemy import func

    if scope not in ("active", "archived", "deleted"):
        raise HTTPException(400, "scope 须为 active、archived 或 deleted")

    # NULL flags are treated as 0, matching _get_alive_task and the
    # ``or 0`` defaults below; a plain ``== 0`` would hide such rows entirely.
    is_delete = func.coalesce(VectorMatchTask.is_delete, 0)
    is_archived = func.coalesce(VectorMatchTask.is_archived, 0)

    q = db.query(VectorMatchTask)
    if scope == "deleted":
        q = q.filter(is_delete == 1)
    else:
        q = q.filter(is_delete == 0, is_archived == (1 if scope == "archived" else 0))

    tasks = q.order_by(VectorMatchTask.created_time.desc()).all()
    return [
        TaskListItem(
            id=t.id,
            task_code=t.task_code,
            match_mode=t.match_mode,
            candidate_scope=t.candidate_scope,
            source_dataset_name=t.source_dataset.name if t.source_dataset else None,
            target_dataset_name=t.target_dataset.name if t.target_dataset else None,
            status=t.status,
            is_archived=t.is_archived or 0,
            is_delete=t.is_delete or 0,
            created_time=t.created_time,
        )
        for t in tasks
    ]


@app.post("/api/task/{task_id}/archive")
def archive_task(task_id: int, db: Session = Depends(get_db)):
    """Archive a non-deleted task; 404 otherwise."""
    task = _get_alive_task(db, task_id)
    if not task:
        raise HTTPException(404, "Task not found")
    task.is_archived = 1
    db.commit()
    return {"status": "ok"}


@app.post("/api/task/{task_id}/unarchive")
def unarchive_task(task_id: int, db: Session = Depends(get_db)):
    """Un-archive a non-deleted task; 404 otherwise."""
    task = _get_alive_task(db, task_id)
    if not task:
        raise HTTPException(404, "Task not found")
    task.is_archived = 0
    db.commit()
    return {"status": "ok"}


@app.delete("/api/task/{task_id}")
def delete_task(task_id: int, db: Session = Depends(get_db)):
    """Soft delete: set is_delete=1; the task's data stays in the database."""
    task = _get_alive_task(db, task_id)
    if not task:
        raise HTTPException(404, "Task not found")
    task.is_delete = 1
    db.commit()
    return {"status": "ok"}


@app.post("/api/task/{task_id}/restore")
def restore_task(task_id: int, db: Session = Depends(get_db)):
    """Restore a task from the recycle bin; 404 if missing or not deleted."""
    task = db.query(VectorMatchTask).get(task_id)
    if not task or (task.is_delete or 0) != 1:
        raise HTTPException(404, "Task not found or not deleted")
    task.is_delete = 0
    db.commit()
    return {"status": "ok"}


# ─── Match results ─────────────────────────────────────────────────────────
@app.get("/api/task/{task_id}/results", response_model=MatchResultPage)
def get_task_results(
    task_id: int,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    search: Optional[str] = None,
    level: Optional[str] = None,
    sort: str = "score_desc",
    db: Session = Depends(get_db),
):
    """Return one page of top-ranked (rank == 1) matches of a task.

    ``level`` filters by match level (``"all"`` or empty = no filter).
    ``search`` is a case-insensitive substring matched against the source or
    target text; it is applied in SQL *before* pagination so ``total`` and
    page contents agree. ``sort`` is ``score_desc``, ``score_asc``, or anything
    else for source-row order.
    """
    from sqlalchemy import or_
    from sqlalchemy.orm import aliased

    if not _get_alive_task(db, task_id):
        raise HTTPException(404, "Task not found")

    query = db.query(MatchResult).filter(MatchResult.task_id == task_id, MatchResult.rank == 1)
    if level and level != "all":
        query = query.filter(MatchResult.match_level == level)
    if search:
        src_alias = aliased(VectorDataRow)
        tgt_alias = aliased(VectorDataRow)
        pattern = f"%{_escape_like(search)}%"
        # Outer joins: a missing row simply cannot match, it does not drop the other side.
        query = (
            query.outerjoin(src_alias, src_alias.id == MatchResult.source_row_id)
            .outerjoin(tgt_alias, tgt_alias.id == MatchResult.target_row_id)
            .filter(or_(
                src_alias.raw_text.ilike(pattern, escape=_LIKE_ESCAPE),
                tgt_alias.raw_text.ilike(pattern, escape=_LIKE_ESCAPE),
            ))
        )

    if sort == "score_desc":
        query = query.order_by(MatchResult.similarity_score.desc())
    elif sort == "score_asc":
        query = query.order_by(MatchResult.similarity_score.asc())
    else:
        query = query.order_by(MatchResult.source_row_id)

    total = query.count()
    results = query.offset((page - 1) * page_size).limit(page_size).all()

    rows = _load_rows(db, [r.source_row_id for r in results] + [r.target_row_id for r in results])
    items = []
    for r in results:
        src_row = rows.get(r.source_row_id)
        tgt_row = rows.get(r.target_row_id)
        items.append(MatchResultItem(
            id=r.id,
            source_row_id=r.source_row_id,
            source_row_number=src_row.row_number if src_row else 0,
            source_text=src_row.raw_text if src_row else "",
            target_text=tgt_row.raw_text if tgt_row else "",
            similarity_score=r.similarity_score,
            rerank_score=r.rerank_score,
            match_level=r.match_level or "",
            candidate_scope=r.candidate_scope,
            is_confirmed=r.is_confirmed,
        ))
    return MatchResultPage(items=items, total=total, page=page, page_size=page_size)


# ─── Candidate details for a source row ────────────────────────────────────
@app.get("/api/task/{task_id}/candidates/{source_row_id}", response_model=SourceWithCandidates)
def get_candidates(task_id: int, source_row_id: int, db: Session = Depends(get_db)):
    """Return a source row with all of its ranked candidates for the task.

    Raises 404 if the task is missing/deleted or the source row does not exist.
    """
    if not _get_alive_task(db, task_id):
        raise HTTPException(404, "Task not found")
    src_row = db.query(VectorDataRow).get(source_row_id)
    if not src_row:
        raise HTTPException(404, "Source row not found")

    results = (
        db.query(MatchResult)
        .filter(MatchResult.task_id == task_id, MatchResult.source_row_id == source_row_id)
        .order_by(MatchResult.rank)
        .all()
    )
    targets = _load_rows(db, [r.target_row_id for r in results])

    candidates = []
    for r in results:
        tgt_row = targets.get(r.target_row_id)
        candidates.append(CandidateDetail(
            rank=r.rank,
            rerank_rank=r.rerank_rank,
            target_row_id=r.target_row_id,
            target_text=tgt_row.raw_text if tgt_row else "",
            similarity_score=r.similarity_score,
            rerank_score=r.rerank_score,
            match_level=r.match_level or "",
            dataset_role="target",
            candidate_scope=r.candidate_scope,
            data_row_id=tgt_row.id if tgt_row else 0,
            is_confirmed=r.is_confirmed,
        ))

    return SourceWithCandidates(
        source_row_id=src_row.id,
        source_text=src_row.raw_text,
        source_row_number=src_row.row_number,
        dataset_role=src_row.dataset_role,
        data_row_id=src_row.id,
        candidates=candidates,
    )


# ─── Confirm match ─────────────────────────────────────────────────────────
def _set_confirmation(db: Session, result_id: int, value: int) -> dict:
    """Set ``is_confirmed`` of one match result (1 = confirmed, -1 = ignored); 404 if missing."""
    result = db.query(MatchResult).get(result_id)
    if not result:
        raise HTTPException(404, "Result not found")
    result.is_confirmed = value
    db.commit()
    return {"status": "ok"}


@app.post("/api/result/{result_id}/confirm")
def confirm_match(result_id: int, db: Session = Depends(get_db)):
    """Mark a match result as confirmed."""
    return _set_confirmation(db, result_id, 1)


@app.post("/api/result/{result_id}/ignore")
def ignore_match(result_id: int, db: Session = Depends(get_db)):
    """Mark a match result as ignored."""
    return _set_confirmation(db, result_id, -1)


# ─── Settings (read/write .env) ────────────────────────────────────────────
_backend_dir = os.path.dirname(os.path.abspath(__file__))
_env_local = os.path.join(_backend_dir, ".env.local")
# Prefer .env.local when present at import time, otherwise .env.
ENV_PATH = _env_local if os.path.exists(_env_local) else os.path.join(_backend_dir, ".env")


def _read_env() -> dict:
    """Parse ENV_PATH into a ``{key: value}`` dict, skipping blanks, comments and malformed lines."""
    result = {}
    if os.path.exists(ENV_PATH):
        with open(ENV_PATH, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith("#") and "=" in line:
                    k, v = line.split("=", 1)
                    result[k.strip()] = v.strip()
    return result


def _write_env(settings: dict):
    """Write *settings* to ENV_PATH as ``KEY=value`` lines.

    Comments and blank lines of the existing file are kept in place, existing
    keys are rewritten where they stand, and new keys are appended. Keys not
    in *settings* (and malformed lines) are dropped, as before. The file is
    replaced atomically so a crash mid-write cannot truncate it.
    """
    existing = []
    if os.path.exists(ENV_PATH):
        with open(ENV_PATH, "r", encoding="utf-8") as f:
            existing = f.read().splitlines()

    out_lines = []
    written = set()
    for raw in existing:
        stripped = raw.strip()
        if not stripped or stripped.startswith("#"):
            out_lines.append(raw)
            continue
        if "=" not in stripped:
            continue
        key = stripped.split("=", 1)[0].strip()
        if key in settings and key not in written:
            out_lines.append(f"{key}={settings[key]}")
            written.add(key)
    out_lines.extend(f"{k}={v}" for k, v in settings.items() if k not in written)

    tmp_path = ENV_PATH + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write("".join(f"{line}\n" for line in out_lines))
    os.replace(tmp_path, ENV_PATH)


def _validate_setting(key, value) -> None:
    """Reject a key/value that would corrupt the env file or inject extra lines (HTTP 400)."""
    key_text = str(key)
    if not key_text.strip() or "=" in key_text or key_text.lstrip().startswith("#"):
        raise HTTPException(400, f"Invalid setting key: {key_text!r}")
    if any(ch in key_text or ch in str(value) for ch in "\r\n"):
        raise HTTPException(400, f"Setting {key_text!r} must not contain line breaks")


@app.get("/api/settings", response_model=SettingsResponse)
def get_settings():
    """Return the raw key/value pairs of the env file.

    NOTE(review): values (including any API key stored there) are returned unmasked.
    """
    return SettingsResponse(settings=_read_env())


@app.post("/api/settings")
async def update_settings(items: List[SettingItem]):
    """Merge *items* into the env file and apply them without a restart.

    Reloads the file into ``os.environ``, refreshes the embedding-service
    module constants, then re-runs the health check so the next status poll
    reflects the new configuration. Raises 400 for keys/values containing line
    breaks, ``=`` in the key, or an empty/commented key.
    """
    for item in items:
        _validate_setting(item.key, item.value)

    current = _read_env()
    for item in items:
        current[item.key] = item.value
    _write_env(current)

    # Reload env vars right after saving, so no manual restart is needed.
    from dotenv import load_dotenv
    load_dotenv(ENV_PATH, override=True)

    # Keep the config constants in embedding_service in sync.
    import services.embedding_service as es
    es.SILICONFLOW_API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")
    es.EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-m3")
    es.EMBEDDING_DIM = int(os.environ.get("EMBEDDING_DIM", "1024"))
    es.RERANKER_MODEL = os.environ.get("RERANKER_MODEL", "Qwen/Qwen3-VL-Reranker-8B")
    es.RERANKER_ENABLED = os.environ.get("RERANKER_ENABLED", "true").lower() == "true"

    # Refresh the health cache immediately so the frontend's next request sees the new state.
    await _do_health_check()

    return {"status": "ok", "message": "已保存,配置已实时生效"}


# ─── Health check (served from the backend cache; responds instantly) ─────
@app.get("/api/health")
async def health_check(force: bool = False):
    """Return the cached health status; refresh first if stale or when force=true."""
    if force or time.time() - _health_cache["updated_at"] > _HEALTH_TTL:
        await _do_health_check()
    return _health_cache["result"]


# ─── Export results ────────────────────────────────────────────────────────
@app.get("/api/task/{task_id}/export")
def export_results(task_id: int, db: Session = Depends(get_db)):
    """Stream all match results of a task (every rank) as an .xlsx download.

    Raises 404 if the task is missing or soft-deleted.
    """
    import io

    import openpyxl
    from openpyxl.styles import Alignment, Font, PatternFill
    from openpyxl.utils import get_column_letter
    from fastapi.responses import StreamingResponse

    task = _get_alive_task(db, task_id)
    if not task:
        raise HTTPException(404, "Task not found")

    results = (
        db.query(MatchResult)
        .filter(MatchResult.task_id == task_id)
        .order_by(MatchResult.source_row_id, MatchResult.rank)
        .all()
    )
    rows = _load_rows(db, [r.source_row_id for r in results] + [r.target_row_id for r in results])

    wb = openpyxl.Workbook()
    ws = wb.active
    ws.title = "匹配结果"

    headers = ["源行号", "源数据内容", "候选排名", "目标候选内容", "相似度(%)", "精排分", "匹配等级", "候选来源"]
    ws.append(headers)

    # Header styling
    header_font = Font(bold=True, color="FFFFFF")
    header_fill = PatternFill(start_color="1F4E79", end_color="1F4E79", fill_type="solid")
    for cell in ws[1]:
        cell.font = header_font
        cell.fill = header_fill
        cell.alignment = Alignment(horizontal="center", vertical="center")

    level_map = {"high": "高度匹配", "possible": "可能匹配", "low_confidence": "低置信", "no_match": "不匹配"}
    scope_map = {"current_task_target": "目标候选集", "history": "历史数据", "standard": "标准库"}

    for r in results:
        src = rows.get(r.source_row_id)
        tgt = rows.get(r.target_row_id)
        ws.append([
            src.row_number if src else "",
            src.raw_text if src else "",
            r.rank,
            tgt.raw_text if tgt else "",
            # Similarity is stored as a fraction; exported as a percentage.
            round(r.similarity_score * 100, 2) if r.similarity_score is not None else "",
            round(r.rerank_score, 4) if r.rerank_score is not None else "",
            level_map.get(r.match_level, r.match_level),
            scope_map.get(r.candidate_scope, r.candidate_scope or ""),
        ])

    # Column widths
    col_widths = [8, 40, 10, 40, 12, 12, 12, 14]
    for i, width in enumerate(col_widths, 1):
        ws.column_dimensions[get_column_letter(i)].width = width

    output = io.BytesIO()
    wb.save(output)
    output.seek(0)

    return StreamingResponse(
        output,
        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        headers={"Content-Disposition": f"attachment; filename=match_result_{task.task_code}.xlsx"},
    )


# Serve frontend static files
_static_dir = os.path.join(os.path.dirname(__file__), "static")
if os.path.isdir(os.path.join(_static_dir, "assets")):
    app.mount("/assets", StaticFiles(directory=os.path.join(_static_dir, "assets")), name="assets")


def _resolve_static_file(full_path: str) -> Optional[str]:
    """Map a request path to an existing file under ``_static_dir``, or return None.

    ``full_path`` is client-controlled and may contain ``..`` segments or be
    absolute (e.g. ``/..%2F.env`` or ``//etc/passwd``), so the joined path is
    normalised and must stay inside the static root.
    """
    if not full_path:
        return None
    static_root = os.path.abspath(_static_dir)
    candidate = os.path.abspath(os.path.join(static_root, full_path))
    try:
        inside = os.path.commonpath([static_root, candidate]) == static_root
    except ValueError:  # e.g. paths on different drives (Windows)
        return None
    return candidate if inside and os.path.isfile(candidate) else None


@app.get("/{full_path:path}")
async def serve_frontend(full_path: str):
    """Catch-all: serve a static file, else index.html for SPA routing.

    Unknown ``/api/...`` paths, and requests when no index.html is built,
    get a real 404 instead of the SPA shell or a 200 error body.
    """
    if full_path == "api" or full_path.startswith("api/"):
        raise HTTPException(404, "Not found")
    file_path = _resolve_static_file(full_path)
    if file_path:
        return FileResponse(file_path)
    index = os.path.join(_static_dir, "index.html")
    if os.path.isfile(index):
        return FileResponse(index)
    raise HTTPException(404, "Not found")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)