File size: 4,676 Bytes
f71153e
9fd268b
 
 
 
f71153e
 
9fd268b
 
f71153e
 
 
 
9fd268b
195b7ea
d627081
9fd268b
 
 
 
 
 
 
 
 
 
 
 
d627081
 
9fd268b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f71153e
9fd268b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f71153e
9fd268b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f71153e
9fd268b
 
 
 
f71153e
9fd268b
 
 
 
 
 
f71153e
9fd268b
 
4e680d3
f71153e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import json
import os
import secrets
from typing import List, Optional, Union

import torch
from fastapi import FastAPI, Security, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field, validator
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = FastAPI()
security = HTTPBearer()  # extracts the "Authorization: Bearer <token>" credentials

# Runtime configuration; every value can be overridden via the environment.
# NOTE(review): when SK_KEY is unset this falls back to a hard-coded
# placeholder key, so deployments should always set SK_KEY explicitly.
SK_KEY = os.getenv("SK_KEY", "sk-aaabbbcccdddeeefffggghhhiiijjjkkk")
MODEL_ID = os.getenv("RERANK_MODEL", "Qwen/Qwen3-Reranker-4B")
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))  # tokenizer truncation length
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Filled in by the startup hook; both stay None until the model is loaded.
model = None
tokenizer = None


class RerankRequest(BaseModel):
    """Request body for ``POST /v1/rerank``.

    ``documents`` may be sent as a JSON array, as a single string, or as a
    string holding a JSON-encoded array; the ``ensure_list`` validator
    normalizes these forms to a list before field validation runs.
    """

    # Task description prepended to every query/document pair.
    instruction: str = Field(
        default="Given a web search query, retrieve relevant passages that answer the query"
    )
    query: str
    documents: Union[List[str], str]
    # Number of highest-scoring results to return (1-50).
    top_k: int = Field(default=5, ge=1, le=50)
    # Number of documents scored per forward pass (1-32).
    batch_size: int = Field(default=4, ge=1, le=32)
    # When False, results contain only the index and the score.
    return_documents: bool = True

    @validator("documents", pre=True)
    def ensure_list(cls, v):
        """Coerce the raw ``documents`` value into a list.

        - list: returned unchanged.
        - None: treated as "no documents" (empty list) rather than the
          literal one-document list ``["None"]``.
        - str: if it starts with ``[`` and parses as a JSON array, the parsed
          list; otherwise the original string wrapped in a one-element list.
        - anything else: its ``str()`` wrapped in a one-element list.
        """
        if isinstance(v, list):
            return v
        if v is None:
            return []
        if isinstance(v, str):
            s = v.strip()
            if s.startswith("["):
                try:
                    vv = json.loads(s)
                except (ValueError, RecursionError):
                    # Not valid JSON (JSONDecodeError subclasses ValueError)
                    # or nested too deeply: treat the text as one document.
                    vv = None
                if isinstance(vv, list):
                    return vv
            return [v]
        return [str(v)]


def _ensure_padding_token(tok, mdl):
    if tok.pad_token_id is None:
        if tok.eos_token_id is not None:
            tok.pad_token = tok.eos_token
            tok.pad_token_id = tok.eos_token_id
        else:
            tid = tok.encode(" ", add_special_tokens=False)[0]
            tok.pad_token_id = tid
            tok.pad_token = tok.decode([tid])
    mdl.config.pad_token_id = tok.pad_token_id


def _logits_to_scores(logits: torch.Tensor) -> torch.Tensor:
    if logits.dim() == 3:
        # [B, T, C]
        if logits.size(-1) >= 2:
            return logits[:, -1, 1]
        return logits[:, -1, 0]
    if logits.dim() == 2:
        # [B, C]
        if logits.size(-1) >= 2:
            return logits[:, 1]
        return logits[:, 0]
    return logits.squeeze(-1)


@app.on_event("startup")
def load_model():
    """Load the reranker model and tokenizer into the module-level globals.

    Registered as a FastAPI startup hook. The model is always placed on the
    CPU in float32 (the module-level DEVICE value is not consulted), switched
    to eval mode, and autograd is disabled for the calling thread.

    If the TORCH_NUM_THREADS environment variable is set, it caps PyTorch's
    intra-op CPU thread count; when unset, PyTorch's default is kept.
    """
    global model, tokenizer

    # Force CPU execution.
    device = torch.device("cpu")
    # Grad mode is thread-local in PyTorch, so this only covers this thread.
    torch.set_grad_enabled(False)

    # Optional, opt-in limit on the number of CPU threads.
    num_threads = os.environ.get("TORCH_NUM_THREADS")
    if num_threads:
        torch.set_num_threads(int(num_threads))

    # NOTE(review): Qwen3-Reranker checkpoints are documented for use as causal
    # LMs scored from "yes"/"no" next-token logits. Loading one through
    # AutoModelForSequenceClassification may attach a freshly initialized
    # classification head (transformers then warns about newly initialized
    # weights) -- confirm the scores are meaningful for the configured MODEL_ID.
    print(f"Loading model on CPU: {MODEL_ID}")
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,  # float32 on CPU
        trust_remote_code=True,
    ).to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        use_fast=True,
        trust_remote_code=True,
    )

    # Batched requests pad, so the tokenizer must have a pad token.
    _ensure_padding_token(tokenizer, model)
    print("✓ Model loaded (CPU)")


@app.post("/v1/rerank")
def rerank(
    req: RerankRequest,
    credentials: HTTPAuthorizationCredentials = Security(security),
):
    """Score every document against the query and return the best matches.

    Each document is formatted together with the instruction and query, and
    the resulting texts are scored in batches of ``req.batch_size`` with the
    globally loaded model. Results are sorted by descending
    ``relevance_score`` and truncated to ``req.top_k``. Each result carries
    the document's position in the request (``index``) and, when
    ``req.return_documents`` is true, the document itself.

    Raises:
        HTTPException(401): SK_KEY is set and the bearer token does not match.
        HTTPException(422): the query is empty.
    """
    token = credentials.credentials
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Compare bytes: compare_digest raises TypeError on
    # non-ASCII str arguments, and the token comes from a client header.
    if SK_KEY and not secrets.compare_digest(
        token.encode("utf-8"), SK_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")

    if not req.query:
        raise HTTPException(status_code=422, detail="query is required")
    if not req.documents:
        return {"results": []}

    pairs = [
        f"{req.instruction}\nQuery: {req.query}\nDocument: {doc}"
        for doc in req.documents
    ]

    scores_all: List[float] = []
    bs = req.batch_size

    for i in range(0, len(pairs), bs):
        # BatchEncoding.to moves every tensor to the model's device in place.
        inputs = tokenizer(
            pairs[i:i + bs],
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        ).to(model.device)

        with torch.inference_mode():
            outputs = model(**inputs)
            scores = _logits_to_scores(outputs.logits)
            scores_all.extend(scores.float().cpu().tolist())

    items = []
    for idx, (doc, sc) in enumerate(zip(req.documents, scores_all)):
        item = {"index": idx, "relevance_score": float(sc)}
        if req.return_documents:
            item["document"] = doc
        items.append(item)

    items.sort(key=lambda x: x["relevance_score"], reverse=True)
    return {"model": MODEL_ID, "query": req.query, "results": items[: req.top_k]}


if __name__ == "__main__":
    # uvicorn was referenced here without ever being imported, so running the
    # file directly raised NameError; it is only needed on this path.
    import uvicorn

    # Pass the app object rather than the "localrerank:app" import string.
    # uvicorn accepts an object when workers == 1 and reload is off, and this
    # keeps working if the file is renamed.
    uvicorn.run(app, host='0.0.0.0', port=7860, workers=1)