File size: 3,496 Bytes
c8c5741
 
 
 
 
 
 
 
 
 
 
 
e418691
e16d5ab
c8c5741
 
 
 
 
1335071
 
c8c5741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdae94f
 
 
 
c8c5741
 
 
 
 
 
bdae94f
 
 
 
 
 
e16d5ab
bdae94f
 
 
 
 
 
 
c8c5741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import re
import secrets
from typing import List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, Depends, Header
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from pydantic import BaseModel

app = FastAPI()

# --- Configuration ---
# Community GGUF repository and the quantized model file downloaded from it.
# Change both together to serve a different model.
MODEL_REPO = "mradermacher/Qwen3-Reranker-0.6B-GGUF"
MODEL_FILE = "Qwen3-Reranker-0.6B.Q5_K_M.gguf"

# Key clients must send as "Authorization: Bearer <key>" so that strangers
# cannot use up this deployment's free compute. Set it through the API_KEY
# environment variable (e.g. as a Space secret).
_DEFAULT_API_KEY = "1qazxsw2"
MY_API_KEY = os.getenv("API_KEY", _DEFAULT_API_KEY)
if MY_API_KEY == _DEFAULT_API_KEY:
    # The fallback is hard-coded in this source file, so anyone who can read
    # the source can use it; it only remains for backward compatibility.
    print(
        "WARNING: API_KEY is not set; falling back to the built-in default "
        "key, which is visible in the source. Set API_KEY to secure this API."
    )

# --- 1. Download and load the model ---
print("Downloading model from Hugging Face Hub...")
# Optional access token read from HF_TOKEN; None when the variable is unset.
hf_token = os.getenv("HF_TOKEN")
# Returns the local path of the downloaded GGUF file, which llama.cpp loads.
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, token=hf_token)

print("Loading model via llama.cpp...")
llm = Llama(
    model_path=model_path,
    # Context window in tokens; each query + document prompt must fit in it.
    # Defaults to 2048; raise it via N_CTX on hardware with more memory.
    n_ctx=int(os.getenv("N_CTX", "2048")),
    # Default of 2 matches the 2 vCPUs of a free Space; override via N_THREADS.
    n_threads=int(os.getenv("N_THREADS", "2")),
    verbose=False,
)

# --- 2. Data models ---
class RerankRequest(BaseModel):
    """Request body accepted by POST /v1/rerank."""

    # Search query that every document is scored against.
    query: str
    # Candidate texts to score; each result's "index" is a position in this list.
    documents: List[str]
    # Optional cap on the number of ranked results returned; None returns all.
    top_n: Optional[int] = None

class ModelList(BaseModel):
    """Envelope shaped like the GET /v1/models response: {"object": "list", "data": [...]}."""

    # NOTE(review): not referenced by any route visible in this file;
    # list_models returns an equivalent plain dict. Consider using this as its
    # response_model, or removing it.
    object: str = "list"
    # Model descriptor dicts.
    data: list

# --- 3. Authentication dependency ---
async def verify_api_key(authorization: str = Header(None)):
    """Reject the request unless it carries the configured bearer token.

    Used as a FastAPI dependency. The ``Authorization`` header must equal
    ``"Bearer <MY_API_KEY>"`` exactly.

    Raises:
        HTTPException: 401 when the header is missing or does not match.
    """
    expected = f"Bearer {MY_API_KEY}"
    # compare_digest takes time independent of where the inputs differ, so
    # response timing does not reveal how much of a guessed key is correct.
    # Bytes are compared because the str form raises TypeError on non-ASCII
    # input, which would turn a bad header into a 500 instead of a 401.
    if not authorization or not secrets.compare_digest(
        authorization.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized: Invalid API Key")

# --- 4. Core endpoints ---
@app.get("/v1/models")
async def list_models():
    """Advertise the single served model in a models-list envelope.

    This route has no API-key dependency, so clients can discover the model
    without authenticating.
    """
    # "id" is the model name client apps (e.g. Cherry Studio) will display.
    model_card = {
        "id": "qwen3-reranker-0.6b",
        "object": "model",
        "created": 1700000000,
        "owned_by": "huggingface",
    }
    return {"object": "list", "data": [model_card]}

# First integer or decimal number appearing in the model's reply.
_SCORE_RE = re.compile(r"\d+(?:\.\d+)?")


def _parse_score(text: str) -> float:
    """Turn the model's raw completion into a relevance score in [0, 100].

    The prompt asks for a number from 0 to 100, but the completion may carry
    surrounding characters (e.g. ``" 85."`` or ``"72/100"``). The first number
    found is therefore used, rather than requiring the whole string to be
    digits. Text with no number scores ``0.0``; out-of-range values are clamped.
    """
    match = _SCORE_RE.search(text)
    if match is None:
        return 0.0
    return min(max(float(match.group()), 0.0), 100.0)


@app.post("/v1/rerank", dependencies=[Depends(verify_api_key)])
async def rerank(request: RerankRequest):
    """Score every document against the query and return them ranked.

    Each document is scored independently by prompting the model for a
    0-100 relevance number. Each result keeps the document's original
    position in ``index``. Results are sorted by ``relevance_score`` in
    descending order and optionally truncated to ``top_n``.

    Raises:
        HTTPException: 400 when ``top_n`` is negative.
    """
    top_n = request.top_n
    if top_n is not None and top_n < 0:
        # A negative slice bound would silently drop results from the end of
        # the ranking instead of limiting how many are returned.
        raise HTTPException(status_code=400, detail="top_n must be a non-negative integer")

    results = []
    for idx, doc in enumerate(request.documents):
        # NOTE: a generic relevance prompt, not the model's official reranking
        # template; adjust it to the prompt format the Qwen reranker expects.
        prompt = f"Query: {request.query}\nDocument: {doc}\nScore the relevance from 0 to 100:"

        # NOTE(review): this call is synchronous inside an async handler, so it
        # blocks the event loop until every document is scored. llama_cpp is
        # expected to raise when a prompt exceeds n_ctx — confirm; one long
        # document would then fail the whole request.
        response = llm(
            prompt,
            # Qwen-family tokenizers split numbers into single digits, and a
            # leading space can be its own token, so " 100" may need 4 tokens.
            # The former limit of 2 could truncate scores (e.g. " 85" -> "8").
            max_tokens=8,
            stop=["\n"],
            echo=False,
        )

        try:
            text_output = response["choices"][0]["text"]
        except (KeyError, IndexError, TypeError):
            # Unexpected response shape: scored 0, like any unparsable reply.
            text_output = ""

        results.append({
            "index": idx,
            "document": doc,
            "relevance_score": _parse_score(text_output),
        })

    # Highest score first; list.sort is stable, so ties keep input order.
    results.sort(key=lambda r: r["relevance_score"], reverse=True)

    if top_n is not None:
        results = results[:top_n]

    return {"results": results}

# Root route, used as a liveness probe.
@app.get("/")
def read_root():
    """Report that the service is up and name the GGUF file it serves."""
    payload = dict(status="running", model=MODEL_FILE)
    return payload

if __name__ == "__main__":
    # Hugging Face Spaces exposes port 7860 by default. PORT overrides it
    # (e.g. for local runs) without editing the file.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))