# bgeM3Node / app.py — author: VietCat
# Commit c58569a: "fix permission issue for cache, and remove pooling" (782 bytes)
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
# ASGI application on which this module's route decorators register handlers.
app = FastAPI()

# Load model
# Hugging Face hub id of the embedding model. Tokenizer and encoder weights are
# both resolved from it once, at import time, before any request is served.
model_name = "BAAI/bge-m3"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModel.from_pretrained(model_name),
)
class InputText(BaseModel):
    """Request body accepted by the ``/embed`` endpoint."""

    # The text to embed; required field, validated by pydantic before the
    # handler runs.
    text: str
@app.get("/")
def root():
    """Report that the embedding service is up and serving requests."""
    status = dict(message="BAAI/bge-m3 embedding API is running.")
    return status
@app.post("/embed")
def get_embedding(data: InputText):
    """Return the dense embedding of ``data.text`` using CLS pooling.

    The text is tokenized (with truncation enabled, so over-long inputs are
    cut to the tokenizer's configured maximum length) and run through the
    encoder once. The final hidden state at the first token position (the
    CLS / ``<s>`` token) is returned as a plain list of floats.

    Args:
        data: request body; only its ``text`` field is used.

    Returns:
        ``{"embedding": [float, ...]}`` with one value per hidden unit of
        the model.

    NOTE(review): the vector is returned as-is, without L2 normalisation.
    The reference BGE-M3 dense output is normalised, so clients that compare
    vectors by raw dot product should normalise first. Confirm what
    downstream consumers expect before changing this.
    """
    # A single string yields a batch of one, so padding is unnecessary.
    inputs = tokenizer(data.text, return_tensors="pt", truncation=True)
    # Keep inputs on the model's device (a no-op while the model is on CPU).
    inputs = inputs.to(model.device)
    # inference_mode is the lower-overhead successor of no_grad for pure
    # forward passes: no autograd graph and no version-counter bookkeeping.
    with torch.inference_mode():
        outputs = model(**inputs)
    # CLS pooling: batch item 0, token position 0. Explicit indexing always
    # yields a 1-D vector, whereas .squeeze() would drop any size-1 axis.
    embedding = outputs.last_hidden_state[0, 0].tolist()
    return {"embedding": embedding}