import os
import json
from typing import List, Union

import torch
import uvicorn
from fastapi import FastAPI, Security, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field, validator
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = FastAPI()
security = HTTPBearer()

# Configuration comes from environment variables; the SK_KEY default below is a
# placeholder and should be overridden in any real deployment.
SK_KEY = os.environ.get("SK_KEY", "sk-aaabbbcccdddeeefffggghhhiiijjjkkk")
MODEL_ID = os.environ.get("RERANK_MODEL", "Qwen/Qwen3-Reranker-4B")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "512"))
# Detected device; load_model() below currently pins the model to CPU regardless.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Populated once by the startup hook.
model = None
tokenizer = None

class RerankRequest(BaseModel):
    instruction: str = Field(
        default="Given a web search query, retrieve relevant passages that answer the query"
    )
    query: str
    documents: Union[List[str], str]
    top_k: int = Field(default=5, ge=1, le=50)
    batch_size: int = Field(default=4, ge=1, le=32)
    return_documents: bool = True

    @validator("documents", pre=True)
    def ensure_list(cls, v):
        # Accept a list, a JSON-encoded list (e.g. '["a", "b"]'), or a bare string.
        if isinstance(v, list):
            return v
        if isinstance(v, str):
            s = v.strip()
            if s.startswith("["):
                try:
                    vv = json.loads(s)
                    if isinstance(vv, list):
                        return vv
                except Exception:
                    pass
            return [v]
        return [str(v)]

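# Example payload accepted by RerankRequest (an illustration, not part of the
# handler below): "documents" may be sent either as a JSON array or as a single
# string containing a JSON-encoded array.
#
#   {"query": "best pizza in town",
#    "documents": ["Mario's has wood-fired pizza.", "The library opens at 9am."],
#    "top_k": 1}
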
def _ensure_padding_token(tok, mdl):
    # Batched tokenization needs a pad token; fall back to EOS, then to the
    # token for a single space if the tokenizer defines neither.
    if tok.pad_token_id is None:
        if tok.eos_token_id is not None:
            tok.pad_token = tok.eos_token
            tok.pad_token_id = tok.eos_token_id
        else:
            tid = tok.encode(" ", add_special_tokens=False)[0]
            tok.pad_token_id = tid
            tok.pad_token = tok.decode([tid])
    mdl.config.pad_token_id = tok.pad_token_id

def _logits_to_scores(logits: torch.Tensor) -> torch.Tensor:
    if logits.dim() == 3:
        # [batch, seq, num_labels]: take the score at the last position.
        if logits.size(-1) >= 2:
            return logits[:, -1, 1]
        return logits[:, -1, 0]
    if logits.dim() == 2:
        # [batch, num_labels]: use the second label as "relevant" when present.
        if logits.size(-1) >= 2:
            return logits[:, 1]
        return logits[:, 0]
    return logits.squeeze(-1)

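# Note: the values produced above are raw logits, not probabilities. If a 0-1
# range is needed, callers can rescale them downstream (for example with a
# softmax over the label dimension before indexing); the endpoint below returns
# the raw scores unchanged.
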
@app.on_event("startup")
def load_model():
    global model, tokenizer

    # The model is pinned to CPU in float32 here; DEVICE above is informational.
    device = torch.device("cpu")
    torch.set_grad_enabled(False)

    print(f"Loading model on CPU: {MODEL_ID}")
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    ).to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        use_fast=True,
        trust_remote_code=True,
    )

    _ensure_padding_token(tokenizer, model)
    print("✓ Model loaded (CPU)")

@app.post("/v1/rerank")
def rerank(
    req: RerankRequest,
    credentials: HTTPAuthorizationCredentials = Security(security),
):
    # Simple bearer-token check against the configured key.
    token = credentials.credentials
    if SK_KEY and token != SK_KEY:
        raise HTTPException(status_code=401, detail="Invalid token")

    if not req.query:
        raise HTTPException(status_code=422, detail="query is required")
    if not req.documents:
        return {"results": []}

    # One "instruction + query + document" prompt per candidate document.
    pairs = [
        f"{req.instruction}\nQuery: {req.query}\nDocument: {doc}"
        for doc in req.documents
    ]

    scores_all: List[float] = []
    bs = req.batch_size

    for i in range(0, len(pairs), bs):
        batch_pairs = pairs[i:i + bs]
        inputs = tokenizer(
            batch_pairs,
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        for k in inputs:
            inputs[k] = inputs[k].to(model.device)

        with torch.inference_mode():
            outputs = model(**inputs)
            scores = _logits_to_scores(outputs.logits)
            scores_all.extend(scores.detach().float().cpu().tolist())

    items = []
    for idx, (doc, sc) in enumerate(zip(req.documents, scores_all)):
        item = {"index": idx, "relevance_score": float(sc)}
        if req.return_documents:
            item["document"] = doc
        items.append(item)

    # Highest relevance first; return only the top_k results.
    items.sort(key=lambda x: x["relevance_score"], reverse=True)
    return {"model": MODEL_ID, "query": req.query, "results": items[: req.top_k]}

if __name__ == "__main__":
    uvicorn.run("localrerank:app", host="0.0.0.0", port=7860, workers=1)
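
# Example request against a locally running instance (assumes this file is saved
# as localrerank.py and SK_KEY is exported in the environment):
#
#   curl -X POST http://localhost:7860/v1/rerank \
#     -H "Authorization: Bearer $SK_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"query": "what does the service do?",
#          "documents": ["It reranks passages.", "Unrelated text."],
#          "top_k": 2}'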