tharu22 commited on
Commit
03e26b5
·
1 Parent(s): 602ab37
Files changed (1) hide show
  1. main.py +59 -21
main.py CHANGED
@@ -1,34 +1,72 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from sentence_transformers import SentenceTransformer
4
  import numpy as np
5
 
6
- import os
7
- os.environ["HF_HOME"] = "/tmp/huggingface"
8
- # Initialize FastAPI app
9
  app = FastAPI()
10
 
11
- from sentence_transformers import SentenceTransformer
 
 
12
 
13
- embedding_model = SentenceTransformer("Alibaba-NLP/gte-base-en-v1.5")
 
 
14
 
 
 
 
15
 
16
- # Define request body structure
17
- class TextRequest(BaseModel):
18
- text: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Define response structure
21
- class EmbeddingResponse(BaseModel):
22
- dimensions: int
23
- embedding: list[float]
24
 
25
- # Create API endpoint
26
- @app.post("/get_embedding", response_model=EmbeddingResponse)
27
- async def get_embedding(request: TextRequest):
28
- # Generate embedding
29
- embedding = embedding_model.encode([request.text])[0] # Extract first item
 
 
30
 
31
- # Convert to list and return response
32
- return {"dimensions": len(embedding), "embedding": embedding.tolist()}
 
 
33
 
34
 
 
1
+ rom fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from sentence_transformers import SentenceTransformer, util
4
  import numpy as np
5
 
6
+ # Initialize the FastAPI app
 
 
7
  app = FastAPI()
8
 
9
+ # Load the embedding models
10
+ embedding_model = SentenceTransformer("Alibaba-NLP/gte-base-en-v1.5", trust_remote_code=True)
11
+ similarity_model = SentenceTransformer("all-MiniLM-L6-v2")
12
 
13
+ # Define request body schemas
14
+ class TextInput(BaseModel):
15
+ text: str
16
 
17
+ class SimilarityInput(BaseModel):
18
+ text1: str
19
+ text2: str
20
 
21
+ # Home route
22
+ @app.get("/")
23
+ async def home():
24
+ return {"message": "Welcome to the embedding and similarity API. Use /docs to test the endpoints."}
25
+
26
+ # Endpoint for generating embeddings
27
+ @app.post("/embed")
28
+ async def generate_embedding(text_input: TextInput):
29
+ """
30
+ Generate a 768-dimensional embedding for the input text.
31
+ Returns the embedding in a structured format with rounded values.
32
+ """
33
+ try:
34
+ embedding = embedding_model.encode(text_input.text, convert_to_tensor=True).cpu().numpy()
35
+ rounded_embedding = np.round(embedding, decimals=2).tolist()
36
+ dimensions = len(rounded_embedding)
37
+
38
+ return {
39
+ "dimensions": dimensions,
40
+ "embeddings": rounded_embedding
41
+ }
42
+ except Exception as e:
43
+ raise HTTPException(status_code=500, detail=str(e))
44
+
45
+ # New endpoint for calculating cosine similarity
46
+ @app.post("/similarity")
47
+ async def calculate_similarity(similarity_input: SimilarityInput):
48
+ """
49
+ Calculate cosine similarity between two text inputs.
50
+ """
51
+ try:
52
+ # Compute embeddings
53
+ embeddings1 = similarity_model.encode(similarity_input.text1, convert_to_tensor=True)
54
+ embeddings2 = similarity_model.encode(similarity_input.text2, convert_to_tensor=True)
55
 
56
+ # Compute cosine similarity
57
+ cosine_similarity = util.cos_sim(embeddings1, embeddings2).item()
 
 
58
 
59
+ return {
60
+ "text1": similarity_input.text1,
61
+ "text2": similarity_input.text2,
62
+ "cosine_similarity": round(cosine_similarity, 4)
63
+ }
64
+ except Exception as e:
65
+ raise HTTPException(status_code=500, detail=str(e))
66
 
67
+ # Run the FastAPI app
68
+ if __name__ == "__main__":
69
+ import uvicorn
70
+ uvicorn.run(app, host="0.0.0.0", port=7860)
71
 
72