Spaces:
Sleeping
Sleeping
import asyncio
import functools
import os
import pickle
import threading
from typing import List

import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Identifier handed to SentenceTransformer below (presumably a Hugging Face
# Hub repo id; could also be a local path — confirm).
model_name = "taghyan/model"

# Pickle file that update_cache reads from and rewrites. The path is relative,
# so it resolves against the process's current working directory.
embedding_file = "embeddings_sbert.pkl"

app = FastAPI(
    title="SBERT Embedding API",
    description="API for generating sentence embeddings using SBERT",
    version="1.0",
)

# Built once at import time and shared by every handler in this module.
# NOTE(review): the original note says the model "will be cached after first
# load" — presumably the library's local download cache; confirm.
model = SentenceTransformer(model_name)
class TextRequest(BaseModel):
    """Request body carrying a single text to embed."""

    text: str  # the text handed to the encoder
class TextsRequest(BaseModel):
    """Request body carrying a batch of texts to embed."""

    texts: List[str]  # encoded together; output order follows input order
class EmbeddingResponse(BaseModel):
    """Response body holding one embedding vector."""

    embedding: List[float]  # a single vector as plain floats
class EmbeddingsResponse(BaseModel):
    """Response body holding one embedding vector per input text."""

    embeddings: List[List[float]]  # outer list: texts; inner list: vector
def read_root():
    """Return a fixed payload naming this service.

    NOTE(review): no route decorator is visible on this function (it may have
    been lost when the file was captured). As written it is not registered
    with ``app`` — confirm the intended path, presumably ``@app.get("/")``.
    """
    service_name = "SBERT Embedding Service"
    return dict(message=service_name)
async def embed_text(request: TextRequest):
    """Generate the embedding for a single text.

    Args:
        request: Body whose ``text`` field is encoded with the module-level
            ``model``.

    Returns:
        ``{"embedding": [...]}``, with the vector as a plain list of floats
        (the shape described by ``EmbeddingResponse``).

    NOTE(review): no route decorator is visible on this handler (it may have
    been lost when the file was captured); confirm how it is registered.
    """
    # model.encode runs model inference synchronously. Called directly inside
    # an ``async def`` it would block the event loop, stalling every other
    # in-flight request, so run it on the default thread-pool executor and
    # await the result instead.
    loop = asyncio.get_running_loop()
    vector = await loop.run_in_executor(
        None,
        functools.partial(model.encode, request.text, convert_to_numpy=True),
    )
    return {"embedding": vector.tolist()}
async def embed_texts(request: TextsRequest):
    """Generate embeddings for a batch of texts.

    Args:
        request: Body whose ``texts`` list is encoded in one ``model.encode``
            call.

    Returns:
        ``{"embeddings": [[...], ...]}``, one list of floats per input text
        (the shape described by ``EmbeddingsResponse``).

    NOTE(review): no route decorator is visible on this handler (it may have
    been lost when the file was captured); confirm how it is registered.
    """
    # Batch inference is the most expensive call in this service. Running it
    # inline in an ``async def`` would freeze the event loop for its whole
    # duration, so offload it to the default thread-pool executor.
    # NOTE(review): show_progress_bar=True draws a progress bar on every
    # request; consider disabling it for a server process.
    loop = asyncio.get_running_loop()
    vectors = await loop.run_in_executor(
        None,
        functools.partial(
            model.encode,
            request.texts,
            show_progress_bar=True,
            convert_to_numpy=True,
        ),
    )
    return {"embeddings": vectors.tolist()}
# Serializes read-modify-write cycles on the pickle cache. The file work below
# runs in worker threads, so without this lock two overlapping requests could
# both read the old cache and one request's append would be lost.
_CACHE_LOCK = threading.Lock()


def _append_to_cache(path, vectors):
    """Append *vectors* (a list of lists) to the pickled list at *path*.

    Blocking; intended to run in a worker thread. Returns the total number of
    stored embeddings after the append. The new contents are written to a
    sibling temp file and then swapped in with ``os.replace``, so a failure
    mid-write cannot leave a truncated cache behind.
    """
    with _CACHE_LOCK:
        if os.path.exists(path):
            # NOTE(review): pickle.load can execute arbitrary code stored in
            # the file. This is acceptable only while nothing untrusted can
            # write to *path*.
            with open(path, 'rb') as f:
                existing = pickle.load(f)
        else:
            existing = []
        # If the cache file was produced elsewhere as a NumPy array, ``+``
        # would broadcast element-wise instead of concatenating. Normalize it
        # to a plain list first.
        if isinstance(existing, np.ndarray):
            existing = existing.tolist()
        updated = existing + vectors
        tmp_path = os.fspath(path) + '.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(updated, f)
        os.replace(tmp_path, path)
        return len(updated)


async def update_cache(request: TextsRequest):
    """Encode ``request.texts`` and append the vectors to ``embedding_file``.

    The cache is a pickled list of embedding vectors. Only vectors are stored,
    not the texts they came from.

    Args:
        request: Body whose ``texts`` are encoded and appended.

    Returns:
        A dict with a human-readable ``message`` and ``total_embeddings``, the
        cache size after the append.

    NOTE(review): no route decorator is visible on this handler (it may have
    been lost when the file was captured); confirm how it is registered.
    """
    loop = asyncio.get_running_loop()
    # Both model inference and file I/O are blocking, so neither runs on the
    # event loop thread.
    new_embeddings = await loop.run_in_executor(
        None,
        functools.partial(model.encode, request.texts, show_progress_bar=True),
    )
    total = await loop.run_in_executor(
        None, _append_to_cache, embedding_file, new_embeddings.tolist()
    )
    return {"message": f"Cache updated with {len(request.texts)} new embeddings", "total_embeddings": total}