# app.py — FastAPI embeddings service using PyTorch BioBERT
# Works on Hugging Face Spaces (CPU Basic, free)
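#
# Assumed dependencies (not pinned in this file): fastapi, uvicorn, pydantic,
# torch, transformers. A minimal local run, assuming uvicorn is installed:
#   uvicorn app:app --host 0.0.0.0 --port 7860
# (7860 is the port Hugging Face Spaces expects; any free port works elsewhere.)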

import os
from typing import List, Optional

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

import torch
from transformers import AutoTokenizer, AutoModel

HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "monologg/biobert_v1.1_pubmed").strip()
MAX_LEN = int(os.environ.get("MAX_LEN", "128"))
TORCH_THREADS = int(os.environ.get("TORCH_THREADS", "1"))

torch.set_num_threads(TORCH_THREADS)

# --------- Load model & tokenizer (PyTorch) ----------
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
model = AutoModel.from_pretrained(HF_MODEL_ID)
model.eval()  # inference mode
DEVICE = "cpu"
model.to(DEVICE)

# --------- FastAPI ----------
app = FastAPI(title="BioBERT (PyTorch) Embeddings API", version="1.0")

# CORS (permissive for demo use; tighten allow_origins in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)

class EmbReq(BaseModel):
    input: str
    max_len: Optional[int] = None
    pooling: Optional[str] = "cls"  # "cls" or "mean"

class BatchEmbReq(BaseModel):
    inputs: List[str]
    max_len: Optional[int] = None
    pooling: Optional[str] = "cls"  # "cls" or "mean"

@app.get("/")
def root():
    return {
        "name": "BioBERT Embeddings (PyTorch)",
        "model": HF_MODEL_ID,
        "device": DEVICE,
        "endpoints": ["/health", "/v1/embeddings", "/v1/embeddings/batch"],
        "hint": "POST to /v1/embeddings with {'input': 'your text'}",
    }

@app.get("/health")
def health():
    return {"ok": True, "model": HF_MODEL_ID, "device": DEVICE}

def _pool(outputs, inputs, pooling: str):
    """
    pooling="cls": use CLS (pooler_output if present, else hidden_state[:,0])
    pooling="mean": mean of token embeddings (mask-aware)
    """
    if pooling == "mean":
        last = outputs.last_hidden_state  # [B,T,H]
        mask = inputs["attention_mask"].unsqueeze(-1).type_as(last)  # [B,T,1]
        summed = (last * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts
    # cls
    if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
        return outputs.pooler_output
    return outputs.last_hidden_state[:, 0, :]  # CLS token

def _embed(texts: List[str], max_len: int, pooling: str) -> List[List[float]]:
    enc = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
    )
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    with torch.no_grad():
        outputs = model(**enc)
        vecs = _pool(outputs, enc, pooling=pooling)
    return vecs.cpu().numpy().tolist()

@app.post("/v1/embeddings")
def embeddings(req: EmbReq):
    text = (req.input or "").strip()
    if not text:
        return {"embedding": [], "dim": 0}
    L = int(req.max_len or MAX_LEN)
    pooling = (req.pooling or "cls").lower()
    vec = _embed([text], L, pooling)[0]
    return {"embedding": vec, "dim": len(vec), "pooling": pooling}

@app.post("/v1/embeddings/batch")
def embeddings_batch(req: BatchEmbReq):
    items = [str(t).strip() for t in (req.inputs or []) if str(t).strip()]
    if not items:
        return {"embeddings": [], "dim": 0}
    L = int(req.max_len or MAX_LEN)
    pooling = (req.pooling or "cls").lower()
    vecs = _embed(items, L, pooling)
    return {"embeddings": vecs, "dim": len(vecs[0]), "pooling": pooling}