Spaces:
Runtime error
Runtime error
| import os | |
| from fastapi import FastAPI | |
| from typing import Union, List, Dict, Tuple, Optional | |
| from pydantic import BaseModel, Field | |
| from angle_emb import AnglE | |
| class EmbeddingInput(BaseModel): | |
| input: Union[List[str], Tuple[str], List[Dict], str] = Field(..., description="The input to be encoded") | |
| model: Optional[str] = None | |
| encoding_format: Optional[str] = 'float' | |
| dimensions: Optional[int] = None | |
| user: Optional[str] = None | |
| app = FastAPI() | |
| # Get the model name and path from the environment variables | |
| model_name = os.getenv('MODEL_NAME', default='WhereIsAI/UAE-Large-V1') | |
| model_path = os.getenv('MODEL_PATH', default='models/WhereIsAI/UAE-Large-V1') | |
| # Load the model | |
| try: | |
| angle_model = AnglE.from_pretrained(model_path, pooling_strategy='cls').to('cpu') | |
| except Exception as e: | |
| print(f"Failed to load model from path {model_path}. Error: {str(e)}") | |
| def read_root(): | |
| return { | |
| "model_name": model_name, | |
| "model_path": model_path, | |
| "message": "Model is up and running", | |
| "route_info": { | |
| "/": "Returns the model info", | |
| "/health": "Returns the health status of the application", | |
| "/v1/embeddings": 'POST route to get embeddings. Usage: curl -H "Content-Type: application/json" -d \'{ "input": "Your text string goes here" }\' http://localhost:8080/v1/embeddings' | |
| } | |
| } | |
| def health_check(): | |
| return {"health": "ok"} | |
| def get_embeddings(embedding_input: EmbeddingInput): | |
| # # Check if the input is an empty string | |
| # if not embedding_input.input.strip(): | |
| # return { | |
| # "object": "list", | |
| # "data": [], | |
| # "model": model_name, | |
| # "usage": {"prompt_tokens": 0, "total_tokens": 0}, | |
| # } | |
| # Encode the input text using the model | |
| embeddings = angle_model.encode(embedding_input.input, embedding_size=embedding_input.dimensions) | |
| # Create a response format compatible with OpenAI's API | |
| response = { | |
| "object": "list", | |
| "data": [ | |
| {"object": "embedding", "index": i, "embedding": emb.tolist()} | |
| for i, emb in enumerate(embeddings) | |
| ], | |
| "model": model_name, | |
| "usage": {"prompt_tokens": len(embedding_input.input), "total_tokens": len(embedding_input.input)}, | |
| } | |
| return response | |
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=8080) | |