| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel, Field |
|
|
| from .runtime import Qwen3AneRerankRuntime |
|
|
|
|
class RerankRequest(BaseModel):
    # Request body accepted by the rerank endpoints.
    # Documented with comments rather than a docstring so the generated
    # schema stays exactly as it was.

    # Query text the documents are ranked against.
    query: str
    # Candidate documents to score.
    documents: list[str]
    # Optional model name supplied by the client.
    model: str | None = None
    # Optional limit on returned results; must be >= 1 when provided.
    top_n: int | None = Field(None, ge=1)
    # Whether the caller wants the document text included with each result.
    return_documents: bool = False
    # Optional instruction string supplied by the client.
    instruction: str | None = None
    # Optional caller identifier.
    user: str | None = None
|
|
|
|
def create_app(runtime: Qwen3AneRerankRuntime, default_model_id: str | None = None) -> FastAPI:
    """Build the FastAPI application that serves the reranker.

    Registers ``GET /health`` and ``POST /rerank`` (also exposed as
    ``POST /v1/rerank``). Both handlers close over *runtime*.

    Args:
        runtime: Reranker runtime. The handlers read
            ``runtime.manifest.model_name`` and ``runtime.profiles`` and
            call ``runtime.rerank``.
        default_model_id: Model name to report when set; when falsy the
            manifest's model name is used instead.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI(title="Qwen3 ANE Reranker Service", version="0.1.0")

    # NOTE: the route handlers below are described with comments instead of
    # docstrings, because FastAPI publishes a handler's docstring as the
    # endpoint description and that would alter the generated API docs.

    def _profile_summary(profile: Any) -> dict[str, Any]:
        # Flatten one runtime profile into the shape reported by /health.
        return {
            "id": profile.entry.profile_id,
            "batch_size": profile.entry.batch_size,
            "seq_len": profile.entry.seq_len,
        }

    @app.get("/health")
    def health() -> dict[str, Any]:
        # Report the served model name and every profile the runtime exposes.
        served_model = default_model_id or runtime.manifest.model_name
        return {
            "ok": True,
            "task": "rerank",
            "model": served_model,
            "profiles": [_profile_summary(profile) for profile in runtime.profiles],
        }

    @app.post("/rerank")
    @app.post("/v1/rerank")
    def rerank(req: RerankRequest) -> dict[str, Any]:
        # Validate the request, score the documents through the runtime and
        # return the results. ValueError becomes HTTP 400 and RuntimeError
        # becomes HTTP 500; any other exception propagates unchanged.

        def _as_item(row: Any) -> dict[str, Any]:
            # Convert one runtime result row into a response entry.
            position = row["index"]
            entry = {
                "object": "rerank_result",
                "index": position,
                "relevance_score": row["relevance_score"],
            }
            if req.return_documents:
                # Echo the original text of the document the row refers to.
                entry["document"] = req.documents[position]
            return entry

        try:
            # Empty inputs are rejected here and surface as HTTP 400 below.
            if not req.query:
                raise ValueError("query must not be empty")
            if not req.documents:
                raise ValueError("documents must not be empty")
            if "" in req.documents:
                raise ValueError("documents must not contain empty strings")

            ranked, prompt_tokens = runtime.rerank(
                query=req.query,
                documents=req.documents,
                top_n=req.top_n,
                instruction=req.instruction,
            )

            data = [_as_item(row) for row in ranked]

            # Precedence: client-supplied name, then the configured default,
            # then the name recorded in the runtime manifest.
            model_name = req.model or default_model_id or runtime.manifest.model_name
            token_count = int(prompt_tokens)
            return {
                "object": "list",
                "data": data,
                "model": model_name,
                "usage": {
                    # Both usage counters report the prompt token count.
                    "prompt_tokens": token_count,
                    "total_tokens": token_count,
                },
            }
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        except RuntimeError as exc:
            raise HTTPException(status_code=500, detail=str(exc)) from exc

    return app
|
|