# app.py — FRIDA embedding API (FastAPI service for ai-forever/FRIDA)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional
import numpy as np
from sentence_transformers import SentenceTransformer
# FastAPI application; title/version are surfaced in the auto-generated OpenAPI docs.
app = FastAPI(title="FRIDA Embedding API", version="1.0")
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # rejected by browsers for credentialed requests (the CORS spec forbids
    # "*" together with credentials) — confirm whether credentials are needed.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Hugging Face model id. Loading happens at import time, so the first startup
# blocks until the model weights are downloaded and cached locally.
MODEL_NAME = "ai-forever/FRIDA"
model = SentenceTransformer(MODEL_NAME)
# Dimensionality of the produced vectors, as reported by the loaded model
# (echoed by /metadata and /embed).
EMBED_DIM = model.get_sentence_embedding_dimension()
# Prompt names accepted by /embed. Presumably these mirror the prompt prefixes
# shipped with the FRIDA model card — verify against `model.prompts`.
SUPPORTED_PROMPTS = [
    "search_query",
    "search_document",
    "paraphrase",
    "categorize",
    "categorize_sentiment",
    "categorize_topic",
    "categorize_entailment",
]
class EmbedRequest(BaseModel):
    """Request payload for the /embed endpoint."""

    # Texts to embed; emptiness is rejected inside the endpoint handler.
    texts: List[str] = Field(..., description="Список текстов")
    # FRIDA prompt prefix; the endpoint falls back to "search_document"
    # when this is omitted or explicitly None.
    prompt_name: Optional[str] = Field(default="search_document", description="FRIDA prompt_name")
class EmbedResponse(BaseModel):
    """Response payload: one embedding vector per input text."""

    # Row-major list of vectors; one row per element of the request's `texts`.
    embeddings: List[List[float]]
    # Dimensionality shared by every vector in `embeddings`.
    dim: int
@app.get("/health")
def health():
    """Liveness probe — always reports the service as up."""
    payload = {"status": "ok"}
    return payload
@app.get("/metadata")
def metadata():
    """Describe the loaded model: name, vector size, pooling, and prompts."""
    info = {
        "model": MODEL_NAME,
        "embedding_dim": EMBED_DIM,
        "pooling": "cls",
        "prompts_supported": SUPPORTED_PROMPTS,
    }
    return info
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    """Embed a batch of texts with FRIDA and return L2-normalized vectors.

    Raises:
        HTTPException(400): when `texts` is empty or `prompt_name` is not
            one of SUPPORTED_PROMPTS.
    """
    if not req.texts:
        raise HTTPException(status_code=400, detail="texts must be non-empty")

    chosen_prompt = req.prompt_name or "search_document"
    if chosen_prompt not in SUPPORTED_PROMPTS:
        raise HTTPException(status_code=400, detail=f"Unsupported prompt_name: {chosen_prompt}")

    # Cap the encode batch at 16 while never dropping below a single item.
    n_texts = len(req.texts)
    batch = min(16, max(1, n_texts))

    raw = model.encode(
        req.texts,
        prompt_name=chosen_prompt,
        convert_to_numpy=True,
        normalize_embeddings=True,
        batch_size=batch,
        show_progress_bar=False,
    )
    emb = raw.astype(np.float32)
    return {"embeddings": emb.tolist(), "dim": int(emb.shape[1])}