Spaces:
Paused
Paused
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from typing import List, Optional | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
app = FastAPI(title="FRIDA Embedding API", version="1.0")

# Allow browser clients from any origin, but without credentials.
# When allow_origins=["*"] is combined with allow_credentials=True, Starlette's
# CORSMiddleware reflects the caller's Origin back together with
# "Access-Control-Allow-Credentials: true". Any website could then send
# cookie-authenticated requests on a visitor's behalf and read the responses.
# None of the handlers in this file authenticate requests, so credentials are
# not needed for the API to work cross-origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
MODEL_NAME = "ai-forever/FRIDA"

# The encoder is loaded once, when this module is imported, and the single
# instance is reused by the request handlers. The embedding width is read
# from the loaded model rather than hard-coded.
model = SentenceTransformer(model_name_or_path=MODEL_NAME)
EMBED_DIM = model.get_sentence_embedding_dimension()
# Whitelist of prompt_name values accepted by the embed handler; the chosen
# name is forwarded to SentenceTransformer.encode(prompt_name=...).
# NOTE(review): assumed to mirror the prompts defined in the FRIDA model
# config -- confirm when changing MODEL_NAME.
SUPPORTED_PROMPTS = (
    "search_query search_document paraphrase categorize "
    "categorize_sentiment categorize_topic categorize_entailment"
).split()
class EmbedRequest(BaseModel):
    # Request body of the embedding endpoint.
    # (Plain comments rather than a docstring: pydantic would copy a
    # docstring into the generated JSON schema.)

    # Texts to embed; required. The description means "List of texts".
    texts: List[str] = Field(default=..., description="Список текстов")
    # FRIDA prompt to apply; defaults to "search_document". It may also be
    # sent as null. The embed handler, not this model, checks it against
    # SUPPORTED_PROMPTS.
    prompt_name: Optional[str] = Field(default="search_document", description="FRIDA prompt_name")
class EmbedResponse(BaseModel):
    # Response payload of the embedding endpoint.
    # (Comments rather than a docstring, to keep the generated schema unchanged.)

    # One vector (list of floats) per input text.
    embeddings: List[List[float]] = Field(...)
    # Length of each vector.
    dim: int = Field(...)
@app.get("/health")
def health():
    """Report that the service is up.

    The model is loaded when the module is imported, so if this handler can
    answer at all, loading succeeded.

    Returns:
        The constant payload ``{"status": "ok"}``.
    """
    return {"status": "ok"}
@app.get("/metadata")
def metadata():
    """Describe the served embedding model.

    Returns:
        A dict with:
        - the model id (``MODEL_NAME``);
        - the embedding width read from the loaded model (``EMBED_DIM``);
        - a pooling label;
        - the prompt names the embed handler accepts (``SUPPORTED_PROMPTS``).
    """
    return {
        "model": MODEL_NAME,
        "embedding_dim": EMBED_DIM,
        # NOTE(review): hard-coded label, not read from the loaded model's
        # pooling module -- keep in sync if MODEL_NAME changes.
        "pooling": "cls",
        "prompts_supported": SUPPORTED_PROMPTS,
    }
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    """Embed ``req.texts`` with the loaded FRIDA model.

    The texts are encoded with the requested prompt. They are L2-normalized
    (``normalize_embeddings=True``) and cast to float32 before serialization.

    Args:
        req: Request body. A ``None`` or empty ``prompt_name`` falls back to
            ``"search_document"``.

    Returns:
        ``{"embeddings": [[float, ...], ...], "dim": int}``. There is one row
        per input text, and ``dim`` is the row length.

    Raises:
        HTTPException: 400 if ``texts`` is empty, or if ``prompt_name`` is not
            one of ``SUPPORTED_PROMPTS``.
    """
    if not req.texts:
        raise HTTPException(status_code=400, detail="texts must be non-empty")

    # `or` also maps an explicit null / "" from the client to the default prompt.
    prompt = req.prompt_name or "search_document"
    if prompt not in SUPPORTED_PROMPTS:
        raise HTTPException(status_code=400, detail=f"Unsupported prompt_name: {prompt}")

    vectors = model.encode(
        req.texts,
        convert_to_numpy=True,
        prompt_name=prompt,
        normalize_embeddings=True,
        # Never larger than the request; len(req.texts) >= 1 is guaranteed above.
        batch_size=min(16, len(req.texts)),
        show_progress_bar=False,
    ).astype(np.float32, copy=False)  # no extra copy if already float32
    return {"embeddings": vectors.tolist(), "dim": int(vectors.shape[1])}