Spaces:
Runtime error
Runtime error
import hmac
import os
from enum import Enum
from typing import Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Embedding model, constructed once at import time; the request handler below
# reads this module-level ``model`` object.
_MODEL_ID = "dunzhang/stella_en_1.5B_v5"

# Forwarded verbatim as ``config_kwargs``.
# NOTE(review): presumably these switch off attention/unpadding code paths that
# are unavailable on CPU -- confirm against the model card.
_MODEL_CONFIG_KWARGS = {
    "use_memory_efficient_attention": False,
    "unpad_inputs": False,
}

model = SentenceTransformer(
    _MODEL_ID,
    device="cpu",
    trust_remote_code=True,  # the checkpoint ships its own modelling code
    config_kwargs=_MODEL_CONFIG_KWARGS,
)
class Enum(str, Enum):
    """Prompt names a client may select through ``embedding_type``.

    Members are ``str`` subclasses, so each one compares equal to its value.

    NOTE(review): this class rebinds the name ``Enum`` imported from the
    ``enum`` module. The name is kept because other definitions in this file
    refer to it; consider renaming (e.g. ``EmbeddingType``) in a coordinated
    change.
    """

    s2p_query = "s2p_query"  # sentence-to-passage (retrieval, Q&A)
    s2s_query = "s2s_query"  # sentence-to-sentence (similarity)
class Embedding(BaseModel):
    """Request body accepted by the embedding handler."""

    # Texts to embed; handed to ``model.encode`` as-is.
    input: list[str]
    # Optional prompt selector (``Enum`` here is this file's str-based enum).
    # Declared Optional so that an explicit ``null`` validates the same way as
    # omitting the field; ``None`` means "encode without a named prompt".
    embedding_type: Optional[Enum] = None
# ASGI application object.
app = FastAPI()

# CORS policy: any origin may send POST requests carrying an Authorization
# header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# permissive pairing -- confirm it is intentional.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["POST"],
    "allow_headers": ["Authorization"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
def parse(data: list[list[float]], *, ndigits: int = 8) -> list[list[float]]:
    """Return a copy of *data* with every value rounded to *ndigits* decimals.

    Args:
        data: A sequence of rows, each row a sequence of numbers.
        ndigits: Number of decimal places to keep. Defaults to 8.

    Returns:
        A new list of lists with the same shape as *data*; the input is not
        modified.
    """
    return [[round(value, ndigits) for value in row] for row in data]
def _bearer_token_matches(header_value: Optional[str], expected: Optional[str]) -> bool:
    """Return True only when *header_value* is ``Bearer <expected>``."""
    # Fail closed: a missing header or an unset/empty secret never matches.
    if not header_value or not expected:
        return False
    scheme, _, supplied = header_value.partition(" ")
    if scheme.lower() != "bearer":
        return False
    # Compare as bytes so non-ASCII input cannot raise, and in constant time
    # so the comparison does not leak how much of the secret matched.
    return hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))


async def get_embedding(embedding: Embedding, req: Request):
    """Embed the submitted texts and return the vectors as nested lists.

    NOTE(review): no route decorator is visible on this function in the
    reviewed source; confirm it is actually registered on ``app``.

    Args:
        embedding: Request body; ``input`` holds the texts and
            ``embedding_type`` optionally names the prompt to apply.
        req: Incoming request, read only for its ``Authorization`` header.

    Returns:
        One list of floats per input text, each value rounded by ``parse``.

    Raises:
        HTTPException: 401 when the bearer token is missing or does not match
            the ``token`` environment variable; 400 when ``model`` is ``None``.
    """
    if not _bearer_token_matches(req.headers.get("Authorization"), os.environ.get("token")):
        raise HTTPException(status_code=401, detail="Unauthorized.")
    if model is None:
        raise HTTPException(status_code=400, detail="Model load failed.")

    # Only pass prompt_name when the client selected one; hand the library the
    # plain string value, not the enum member.
    encode_kwargs = {}
    if embedding.embedding_type is not None:
        encode_kwargs["prompt_name"] = embedding.embedding_type.value

    # NOTE(review): ``model.encode`` is a synchronous call inside an async
    # handler, so it runs on the event loop for the duration of the encode.
    vectors = model.encode(embedding.input, **encode_kwargs).tolist()
    return parse(vectors)