| | import gradio as gr |
| | import numpy as np |
| | import os |
| | import pandas as pd |
| | import faiss |
| | from huggingface_hub import hf_hub_download |
| | import time |
| | import json |
| | import fastapi |
| | from fastapi import FastAPI, Request |
| | from fastapi.responses import JSONResponse |
| | from fastapi.middleware.cors import CORSMiddleware |
| | import uvicorn |
| | import threading |
| | import math |
| |
|
| | |
# Directory passed as ``cache_dir`` to every Hugging Face Hub download below.
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Cap native thread pools. The variables are exported for any library or
# child process that reads them from now on; runtimes that were already
# loaded by the imports at the top of the file may not re-read them.
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# faiss is imported before OMP_NUM_THREADS is assigned, so also apply the
# limit through its runtime API to make sure the cap actually takes effect.
faiss.omp_set_num_threads(int(os.environ["OMP_NUM_THREADS"]))
| |
|
| | |
# Shared search state. Both names stay None until the resources have been
# loaded; they are rebound whenever a newer index is swapped in.
index = metadata = None

# Timestamp compared against time.time() and file mtimes by the refresher
# (0 means "never refreshed"), and the minimum number of seconds between
# two refresh checks.
last_updated = 0
index_refresh_interval = 5 * 60
| |
|
| | |
def refresh_index():
    """Poll the Hub for a newer index and swap it in; never returns.

    Meant to run in a background thread. Every 30 seconds it checks whether
    more than ``index_refresh_interval`` seconds have passed since
    ``last_updated``. If so, it re-downloads ``metadata.csv`` and, when that
    file's mtime is newer than ``last_updated``, re-downloads
    ``faiss_index.bin`` and rebinds the module-level ``index`` and
    ``metadata``.

    Any exception (download, parsing, wrong dimension) is logged and the
    previously loaded index/metadata stay in use; the next cycle retries.
    """
    global index, metadata, last_updated

    repo_id = "GOGO198/GOGO_rag_index"

    while True:
        try:
            current_time = time.time()
            if current_time - last_updated > index_refresh_interval:
                print("🔄 检查索引更新...")

                # NOTE(review): force_download=True rewrites the cached file,
                # so its mtime is presumably always "now" and the check below
                # is effectively always true -- confirm whether a content
                # based comparison was intended.
                metadata_path = hf_hub_download(
                    repo_id=repo_id,
                    filename="metadata.csv",
                    cache_dir=CACHE_DIR,
                    token=os.getenv("HF_TOKEN"),
                    force_download=True,
                )

                file_mtime = os.path.getmtime(metadata_path)
                if file_mtime > last_updated:
                    print("📥 检测到新索引,重新加载...")

                    index_path = hf_hub_download(
                        repo_id=repo_id,
                        filename="faiss_index.bin",
                        cache_dir=CACHE_DIR,
                        token=os.getenv("HF_TOKEN"),
                        force_download=True,
                    )
                    new_index = faiss.read_index(index_path)
                    new_metadata = pd.read_csv(metadata_path)

                    # Apply the same 768-dimension guard as the initial load
                    # BEFORE swapping, so a bad upload cannot replace a
                    # working index.
                    if new_index.d != 768:
                        raise ValueError("❌ 索引维度错误:预期768维")

                    index = new_index
                    metadata = new_metadata
                    last_updated = file_mtime

                    print(f"✅ 索引更新完成 | 记录数: {len(new_metadata)}")

        except Exception as e:
            # Keep the thread alive: log and retry on the next cycle.
            print(f"索引更新失败: {str(e)}")

        time.sleep(30)
| |
|
def load_resources():
    """Load all required resources (768-dimensional index only).

    Steps:
      1. Best-effort removal of ``*.lock`` files found directly in
         ``CACHE_DIR``.
      2. Download and read ``faiss_index.bin`` into the global ``index``
         if it is not loaded yet, and verify that it is 768-dimensional.
      3. Download and read ``metadata.csv`` into the global ``metadata``
         if it is not loaded yet.
      4. Start the daemon thread running ``refresh_index``.

    Raises:
        ValueError: the downloaded index is not 768-dimensional.
        Exception: any download/parsing error is logged and re-raised.
    """
    global index, metadata

    # Lock cleanup is best-effort: a file that cannot be removed must not
    # prevent start-up, but the reason is logged instead of being swallowed.
    for lock_file in os.listdir(CACHE_DIR):
        if not lock_file.endswith('.lock'):
            continue
        try:
            os.remove(os.path.join(CACHE_DIR, lock_file))
            print(f"🧹 清理锁文件: {lock_file}")
        except OSError as e:
            print(f"⚠️ 清理锁文件失败: {lock_file} ({e})")

    if index is None or metadata is None:
        print("🔄 正在加载所有资源...")

    if index is None:
        print("📥 正在下载FAISS索引...")
        try:
            index_path = hf_hub_download(
                repo_id="GOGO198/GOGO_rag_index",
                filename="faiss_index.bin",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            index = faiss.read_index(index_path)

            if index.d != 768:
                raise ValueError("❌ 索引维度错误:预期768维")

            print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
        except Exception as e:
            print(f"❌ FAISS索引加载失败: {str(e)}")
            raise

    if metadata is None:
        print("📄 正在下载元数据...")
        try:
            metadata_path = hf_hub_download(
                repo_id="GOGO198/GOGO_rag_index",
                filename="metadata.csv",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            metadata = pd.read_csv(metadata_path)
            print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")
        except Exception as e:
            print(f"❌ 元数据加载失败: {str(e)}")
            raise

    # Start the background refresher only once the initial load succeeded.
    threading.Thread(target=refresh_index, daemon=True).start()


# Load everything at import time so the API can serve its first request.
load_resources()
| |
|
def sanitize_floats(obj):
    """Recursively replace NaN/Infinity floats so the value is JSON-safe.

    Args:
        obj: a float, dict, list, or any other value.

    Returns:
        * ``0.0`` for a float that is NaN or +/-infinity, the float itself
          otherwise;
        * a new dict/list with every value/item sanitized recursively;
        * ``obj`` unchanged for every other type (including tuples).
    """
    if isinstance(obj, float):
        if math.isnan(obj) or math.isinf(obj):
            return 0.0
        return obj
    if isinstance(obj, dict):
        return {k: sanitize_floats(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [sanitize_floats(x) for x in obj]
    return obj
| |
|
def predict(vector):
    """Search the FAISS index for the nearest neighbours of ``vector``.

    Args:
        vector: sequence of numbers; it is converted to a float32 array and
            reshaped to a single query row.

    Returns:
        dict: on success ``{"status": "success", "results": [...]}`` with at
        most 3 entries, each holding ``source``, ``content``, ``confidence``
        (``1 / (1 + distance)`` clamped to [0, 1]) and ``distance``. A hit
        whose metadata lookup fails is reported as an entry with an
        ``error`` key instead. When the index is empty, ``results`` is
        empty and a ``message`` is included. On any other failure
        ``{"status": "error", "message": ..., "details": {"trace": ...}}``
        is returned; the function does not raise.
    """
    # Local import: traceback is not among the module-level imports, and
    # without it the error handler below would itself fail with NameError.
    import traceback

    try:
        print(f"接收向量: {vector[:3]}... (长度: {len(vector)})")

        # Read the globals once so a single request does not mix objects
        # from before and after a rebind by the background refresher.
        current_index, current_metadata = index, metadata

        query_vector = np.array(vector).astype('float32').reshape(1, -1)

        k = min(3, current_index.ntotal)
        if k == 0:
            return {
                "status": "success",
                "results": [],
                "message": "索引为空"
            }

        print(f"执行FAISS搜索 | k={k}")
        D, I = current_index.search(query_vector, k=k)

        print(f"搜索结果索引: {I[0]}")
        print(f"搜索距离: {D[0]}")

        results = []
        for idx, distance in zip(I[0], D[0]):
            # FAISS uses label -1 for "no neighbour found"; iloc[-1] would
            # silently return the LAST metadata row, so skip such slots.
            if idx < 0:
                continue

            try:
                # Non-finite or negative distances are replaced by a large
                # sentinel so the confidence below stays well defined.
                if not np.isfinite(distance) or distance < 0:
                    distance = 100.0

                confidence = 1 / (1 + distance)
                confidence = max(0.0, min(1.0, confidence))

                # Plain Python floats so the payload is JSON-serialisable.
                distance = float(distance)
                confidence = float(confidence)

                row = current_metadata.iloc[idx]
                results.append({
                    "source": row["source"],
                    "content": row.get("content", ""),
                    "confidence": confidence,
                    "distance": distance
                })
            except Exception as e:
                # Keep the remaining hits; report this one as a placeholder.
                results.append({
                    "error": str(e),
                    "confidence": 0.5,
                    "distance": 0.0
                })

        return {
            "status": "success",
            "results": sanitize_floats(results)
        }
    except Exception as e:
        return {
            "status": "error",
            "message": f"服务器内部错误: {str(e)}",
            "details": sanitize_floats({"trace": traceback.format_exc()})
        }
| | |
| | |
# ASGI application that serves the vector-search endpoint.
app = FastAPI()

# CORS is fully open: any origin, method and header, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# very permissive -- confirm that credentialed cross-site calls are needed.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
| |
|
@app.post("/predict")
async def api_predict(request: Request):
    """Prediction API endpoint.

    Expects a JSON object with a non-empty list under the ``vector`` key
    and returns the dict produced by ``predict`` as a JSON response.

    Responses:
        400: the body is not valid JSON, is not a JSON object, or has no
             non-empty list under ``vector``.
        500: any unexpected exception while handling the request.
        200: otherwise (the payload's own ``status`` field may still be
             ``"error"`` when ``predict`` reports a failure).
    """
    try:
        try:
            data = await request.json()
        except ValueError:
            # Malformed JSON (JSONDecodeError is a ValueError) is a client
            # error, not a server failure.
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "无效输入: 请求体不是合法的JSON"}
            )

        # A valid JSON body may still be a list/string/number, which has no
        # .get(); treat that like a missing vector.
        vector = data.get("vector") if isinstance(data, dict) else None

        if not vector or not isinstance(vector, list):
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "无效输入: 需要向量数据"}
            )

        result = predict(vector)
        return JSONResponse(content=result)

    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": f"服务器内部错误了: {str(e)}"
            }
        )
| |
|
| | |
if __name__ == "__main__":
    # Start-up banner summarising the loaded resources.
    separator = "=" * 50
    print(separator)
    print("Space启动完成 | 准备接收请求")
    dimension = index.d
    print(f"索引维度: {dimension}")
    record_count = len(metadata)
    print(f"元数据记录数: {record_count}")
    print(separator)

    # Serve the FastAPI app on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)