File size: 1,450 Bytes
8425d1d 27199b6 8425d1d 27199b6 8425d1d 084ac18 8425d1d 27199b6 8425d1d 27199b6 2cd1c19 27199b6 8425d1d 27199b6 8425d1d 27199b6 8425d1d 27199b6 8425d1d 27199b6 8425d1d 27199b6 8425d1d 27199b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import logging

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Application object; the routes below register themselves on it via decorators.
app = FastAPI()

# Sentence-embedding model, constructed once at import time and used by the
# /embed route below.
# NOTE(review): trust_remote_code=True executes Python code fetched from the
# model repository; consider pinning `revision=` to a reviewed commit.
model = SentenceTransformer(
    "Alibaba-NLP/gte-base-en-v1.5",
    trust_remote_code=True,
)
# Request body accepted by POST /embed: a JSON object with one required field.
class TextInput(BaseModel):
    # The text to embed (required; declared as a string).
    text: str
# Home route: fixed JSON greeting returned for GET "/".
@app.get("/")
async def home():
    # Constant payload; the handler reads nothing from the request.
    greeting = "Welcome to my dashboard"
    return {"message": greeting}
def _embed_rounded(text, decimals=2):
    # Encode `text` with the module-level model and round every component to
    # `decimals` places. Blocking; generate_embedding runs it in a worker
    # thread. Returns a flat list of Python floats, one per dimension.
    embedding = model.encode(text, convert_to_tensor=True).cpu().numpy()
    # Widen to float64 BEFORE rounding: np.round preserves the input dtype,
    # and float32 cannot hold 0.12 exactly, so rounding in float32 and then
    # calling .tolist() yields values like 0.11999999731779099 in the JSON.
    # Rounding in float64 yields the expected short values (e.g. 0.12).
    return np.round(embedding.astype(np.float64), decimals=decimals).tolist()


# Define the API endpoint
@app.post("/embed")
async def generate_embedding(text_input: TextInput):
    """
    Generate a 768-dimensional embedding for the input text.
    Returns the embedding in a structured format with rounded values.
    """
    # FastAPI publishes the docstring above as this endpoint's OpenAPI
    # description, so implementation notes are kept in comments instead.
    #
    # Request:  {"text": "<string>"}
    # Response: {"dimensions": <int>, "embeddings": [[<float>, ...]]}
    #           i.e. one vector wrapped in an outer list, each component
    #           rounded to 2 decimal places.
    # Raises:   HTTPException(500) whose detail is the underlying error text.
    try:
        # model.encode is synchronous inference. Calling it directly inside an
        # `async def` would stall the event loop (and every other request,
        # including "/") until it finishes, so run it in FastAPI's thread pool.
        # NOTE(review): concurrent requests may now encode in parallel threads;
        # add a semaphore if memory use becomes a concern.
        rounded_embedding = await run_in_threadpool(_embed_rounded, text_input.text)
    except Exception as e:
        # Keep the traceback on the server; the client only receives str(e).
        logging.getLogger(__name__).exception("Embedding generation failed")
        # NOTE(review): str(e) may expose internal details to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "dimensions": len(rounded_embedding),
        "embeddings": [rounded_embedding],  # single vector wrapped in a list
    }
# Start a uvicorn server when this file is executed as a script
# (not when it is merely imported).
if __name__ == "__main__":
    import uvicorn

    # Bind on all network interfaces, port 7860.
    serve_host, serve_port = "0.0.0.0", 7860
    uvicorn.run(app, host=serve_host, port=serve_port)