# reranker / app.py
# JamesK123's picture
# Update app.py
# e16d5ab verified
import hmac
import os
import re
from typing import List, Optional

from fastapi import FastAPI, HTTPException, Depends, Header
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from pydantic import BaseModel
import uvicorn
app = FastAPI()

# --- Configuration ---
# Replace with the real community repo name and GGUF file name.
MODEL_REPO = "mradermacher/Qwen3-Reranker-0.6B-GGUF"
MODEL_FILE = "Qwen3-Reranker-0.6B.Q5_K_M.gguf"

# Private API key that keeps other people from abusing this free deployment.
# NOTE(review): the fallback below is a hard-coded value visible in the source;
# set API_KEY in the environment so the default is never the effective secret.
MY_API_KEY = os.getenv("API_KEY", "1qazxsw2")

# --- 1. Download and load the model ---
print("Downloading model from Hugging Face Hub...")
hf_token = os.getenv("HF_TOKEN")
model_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=MODEL_FILE,
    token=hf_token,
)

print("Loading model via llama.cpp...")
llm = Llama(
    model_path=model_path,
    # Context window. The original note says a 4B model with 16 GB RAM can use
    # 2048 or 4096.
    # NOTE(review): the configured model file is a 0.6B one — confirm the value.
    n_ctx=2048,
    # Matches the 2 vCPUs of a free Space.
    n_threads=2,
    verbose=False,
)
# --- 2. Data models ---
class RerankRequest(BaseModel):
    """Request body accepted by the rerank endpoint."""

    # Query text that each document is scored against.
    query: str
    # Candidate documents; results refer back to them by list position.
    documents: List[str]
    # Optional cap on how many results are returned; None returns all of them.
    top_n: Optional[int] = None
class ModelList(BaseModel):
    """Envelope describing a list of model entries."""

    # Object-type tag of the envelope; defaults to "list".
    object: str = "list"
    # Entries carried by the envelope (element type is not constrained here).
    data: list
# --- 3. Authentication dependency ---
async def verify_api_key(authorization: str = Header(None)):
    """Reject the request unless it carries the expected bearer token.

    Args:
        authorization: Value of the ``Authorization`` request header, or
            ``None`` when the header is absent.

    Raises:
        HTTPException: 401 when the header is missing or is not exactly
            ``Bearer <MY_API_KEY>``.
    """
    expected = f"Bearer {MY_API_KEY}"
    supplied = authorization or ""
    # Compare in constant time so response timing does not reveal how much of
    # the key matched. Both sides are encoded to bytes because compare_digest
    # only accepts str arguments when they are pure ASCII.
    if not hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(
            status_code=401,
            detail="Unauthorized: Invalid API Key",
            # Tell the client which authentication scheme is expected.
            headers={"WWW-Authenticate": "Bearer"},
        )
# --- 4. Core endpoint logic ---
@app.get("/v1/models")
async def list_models():
    """Return the models served by this app wrapped in a list envelope.

    The payload contains a single hard-coded entry for the reranker.
    """
    model_entry = {
        # Put here the name you want to see in Cherry Studio.
        "id": "qwen3-reranker-0.6b",
        "object": "model",
        "created": 1700000000,
        "owned_by": "huggingface",
    }
    return {"object": "list", "data": [model_entry]}
# Leading non-negative integer or decimal in the model's reply (ASCII digits only).
_SCORE_PATTERN = re.compile(r"\d+(?:\.\d+)?", re.ASCII)

# Upper bound of the scale requested in the prompt; larger replies are clamped.
_MAX_SCORE = 100.0

# Generation budget for the score. A budget of 2 cannot hold a three-digit
# reply such as "100" when the tokenizer emits one token per digit or spends a
# token on leading whitespace, so leave headroom; generation still stops at the
# first newline.
_SCORE_MAX_TOKENS = 8


def _parse_score(text) -> float:
    """Extract a relevance score in [0, 100] from raw model output, else 0.0."""
    if not isinstance(text, str):
        return 0.0
    found = _SCORE_PATTERN.match(text.strip())
    if found is None:
        return 0.0
    return min(float(found.group()), _MAX_SCORE)


@app.post("/v1/rerank", dependencies=[Depends(verify_api_key)])
async def rerank(request: RerankRequest):
    """Score every document against the query and return them best-first.

    Each document is scored with its own model call: the model is asked for a
    relevance number from 0 to 100 and the leading number of its reply becomes
    the score. Replies without a leading number score 0.0.

    Args:
        request: Query, candidate documents and an optional ``top_n`` cap.

    Returns:
        ``{"results": [...]}`` where each item holds the document's original
        ``index``, the ``document`` text and its ``relevance_score``, sorted by
        score in descending order (ties keep their input order) and truncated
        to ``top_n`` items when it is given.

    Raises:
        HTTPException: 400 when ``top_n`` is negative.
    """
    top_n = request.top_n
    if top_n is not None and top_n < 0:
        # A negative slice bound would silently drop results from the tail
        # instead of limiting the count, so reject it before doing any work.
        raise HTTPException(
            status_code=400,
            detail="top_n must be a non-negative integer",
        )

    query = request.query
    results = []
    for idx, doc in enumerate(request.documents):
        # NOTE(review): this is a generic relevance prompt; it should be
        # adjusted to the prompt format the reranker model expects — confirm.
        prompt = f"Query: {query}\nDocument: {doc}\nScore the relevance from 0 to 100:"
        # NOTE: this call is synchronous, so the handler does not yield to the
        # event loop while the model is generating.
        response = llm(
            prompt,
            max_tokens=_SCORE_MAX_TOKENS,
            stop=["\n"],
            echo=False,
        )
        try:
            raw_text = response["choices"][0]["text"]
        except (KeyError, IndexError, TypeError):
            # Unexpected response shape: fall back to the lowest score.
            raw_text = None
        results.append({
            "index": idx,
            "document": doc,
            "relevance_score": _parse_score(raw_text),
        })

    # Highest score first.
    results.sort(key=lambda item: item["relevance_score"], reverse=True)
    # Truncate when the caller asked for a limited number of results.
    if top_n is not None:
        results = results[:top_n]
    return {"results": results}
# Root route used as a liveness probe
@app.get("/")
def read_root():
    """Report that the service is running and which model file is configured."""
    payload = {"status": "running", "model": MODEL_FILE}
    return payload
if __name__ == "__main__":
    # HF Spaces exposes port 7860 by default
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )