import os
import json
from typing import List, Optional, Union

import torch
import uvicorn
from fastapi import FastAPI, Security, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field, validator
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = FastAPI()
security = HTTPBearer()

SK_KEY = os.environ.get("SK_KEY", "sk-aaabbbcccdddeeefffggghhhiiijjjkkk")
MODEL_ID = os.environ.get("RERANK_MODEL", "Qwen/Qwen3-Reranker-4B")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "512"))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # detected device (load_model below pins the model to CPU)

model = None
tokenizer = None


class RerankRequest(BaseModel):
    instruction: str = Field(
        default="Given a web search query, retrieve relevant passages that answer the query"
    )
    query: str
    documents: Union[List[str], str]
    top_k: int = Field(default=5, ge=1, le=50)
    batch_size: int = Field(default=4, ge=1, le=32)
    return_documents: bool = True

    @validator("documents", pre=True)
    def ensure_list(cls, v):
        # Accept a plain string, a JSON-encoded list, or a real list of documents.
        if isinstance(v, list):
            return v
        if isinstance(v, str):
            s = v.strip()
            if s.startswith("["):
                try:
                    vv = json.loads(s)
                    if isinstance(vv, list):
                        return vv
                except Exception:
                    pass
            return [v]
        return [str(v)]


def _ensure_padding_token(tok, mdl):
    # Some checkpoints ship without a pad token; fall back to EOS or a space token.
    if tok.pad_token_id is None:
        if tok.eos_token_id is not None:
            tok.pad_token = tok.eos_token
            tok.pad_token_id = tok.eos_token_id
        else:
            tid = tok.encode(" ", add_special_tokens=False)[0]
            tok.pad_token_id = tid
            tok.pad_token = tok.decode([tid])
    mdl.config.pad_token_id = tok.pad_token_id


def _logits_to_scores(logits: torch.Tensor) -> torch.Tensor:
    # Reduce the model output to one relevance score per input pair.
    if logits.dim() == 3:  # [B, T, C]: take the last position, positive class if present
        if logits.size(-1) >= 2:
            return logits[:, -1, 1]
        return logits[:, -1, 0]
    if logits.dim() == 2:  # [B, C]: positive class if present
        if logits.size(-1) >= 2:
            return logits[:, 1]
        return logits[:, 0]
    return logits.squeeze(-1)


@app.on_event("startup")
def load_model():
    global model, tokenizer
    # Force CPU
    device = torch.device("cpu")
    torch.set_grad_enabled(False)
    # Optional: cap the number of CPU threads
    # torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "8")))
    print(f"Loading model on CPU: {MODEL_ID}")
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,  # use float32 on CPU
        trust_remote_code=True,
    ).to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        use_fast=True,
        trust_remote_code=True,
    )
    _ensure_padding_token(tokenizer, model)
    print("✓ Model loaded (CPU)")


@app.post("/v1/rerank")
def rerank(
    req: RerankRequest,
    credentials: HTTPAuthorizationCredentials = Security(security),
):
    token = credentials.credentials
    if SK_KEY and token != SK_KEY:
        raise HTTPException(status_code=401, detail="Invalid token")
    if not req.query:
        raise HTTPException(status_code=422, detail="query is required")
    if not req.documents:
        return {"results": []}

    # Build one instruction/query/document string per candidate passage.
    pairs = [
        f"{req.instruction}\nQuery: {req.query}\nDocument: {doc}"
        for doc in req.documents
    ]

    scores_all: List[float] = []
    bs = req.batch_size
    for i in range(0, len(pairs), bs):
        batch_pairs = pairs[i:i + bs]
        inputs = tokenizer(
            batch_pairs,
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        # On CPU the .to(model.device) move is effectively a no-op, but keeping it
        # makes the code uniform if the device ever changes.
        for k in inputs:
            inputs[k] = inputs[k].to(model.device)
        with torch.inference_mode():
            outputs = model(**inputs)
        scores = _logits_to_scores(outputs.logits)
        scores_all.extend(scores.detach().float().cpu().tolist())

    items = []
    for idx, (doc, sc) in enumerate(zip(req.documents, scores_all)):
        item = {"index": idx, "relevance_score": float(sc)}
        if req.return_documents:
            item["document"] = doc
        items.append(item)
    items.sort(key=lambda x: x["relevance_score"], reverse=True)
    return {"model": MODEL_ID, "query": req.query, "results": items[: req.top_k]}


if __name__ == "__main__":
    uvicorn.run("localrerank:app", host="0.0.0.0", port=7860, workers=1)
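
# Example request (a minimal sketch, assuming the service is running locally on
# port 7860 and SK_KEY is left at its default value):
#
#   curl -s http://localhost:7860/v1/rerank \
#     -H "Authorization: Bearer sk-aaabbbcccdddeeefffggghhhiiijjjkkk" \
#     -H "Content-Type: application/json" \
#     -d '{"query": "what is FastAPI",
#          "documents": ["FastAPI is a Python web framework.", "Bananas are yellow."],
#          "top_k": 2}'
#
# The response contains "model", "query", and a "results" list of
# {"index", "relevance_score", "document"} entries sorted by score descending.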