tharu22 commited on
Commit
703dd7a
·
1 Parent(s): 18017a6
Files changed (2) hide show
  1. __pycache__/main.cpython-313.pyc +0 -0
  2. main.py +16 -41
__pycache__/main.cpython-313.pyc CHANGED
Binary files a/__pycache__/main.cpython-313.pyc and b/__pycache__/main.cpython-313.pyc differ
 
main.py CHANGED
@@ -1,29 +1,25 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from sentence_transformers import SentenceTransformer, util
4
  import numpy as np
5
 
6
  # Initialize the FastAPI app
7
  app = FastAPI()
8
 
9
- # Load the embedding models
10
- embedding_model = SentenceTransformer("Alibaba-NLP/gte-base-en-v1.5", trust_remote_code=True)
11
- similarity_model = SentenceTransformer("all-MiniLM-L6-v2")
12
 
13
- # Define request body schemas
14
  class TextInput(BaseModel):
15
  text: str
16
 
17
- class SimilarityInput(BaseModel):
18
- text1: str
19
- text2: str
20
-
21
  # Home route
22
  @app.get("/")
23
  async def home():
24
- return {"message": "Welcome to the embedding and similarity API. Use /docs to test the endpoints."}
25
 
26
- # Endpoint for generating embeddings
27
  @app.post("/embed")
28
  async def generate_embedding(text_input: TextInput):
29
  """
@@ -31,42 +27,21 @@ async def generate_embedding(text_input: TextInput):
31
  Returns the embedding in a structured format with rounded values.
32
  """
33
  try:
34
- embedding = embedding_model.encode(text_input.text, convert_to_tensor=True).cpu().numpy()
 
 
 
35
  rounded_embedding = np.round(embedding, decimals=2).tolist()
 
 
36
  dimensions = len(rounded_embedding)
37
 
 
38
  return {
39
  "dimensions": dimensions,
40
- "embeddings": rounded_embedding
41
  }
42
  except Exception as e:
 
43
  raise HTTPException(status_code=500, detail=str(e))
44
 
45
- # New endpoint for calculating cosine similarity
46
- @app.post("/similarity")
47
- async def calculate_similarity(similarity_input: SimilarityInput):
48
- """
49
- Calculate cosine similarity between two text inputs.
50
- """
51
- try:
52
- # Compute embeddings
53
- embeddings1 = similarity_model.encode(similarity_input.text1, convert_to_tensor=True)
54
- embeddings2 = similarity_model.encode(similarity_input.text2, convert_to_tensor=True)
55
-
56
- # Compute cosine similarity
57
- cosine_similarity = util.cos_sim(embeddings1, embeddings2).item()
58
-
59
- return {
60
- "text1": similarity_input.text1,
61
- "text2": similarity_input.text2,
62
- "cosine_similarity": round(cosine_similarity, 4)
63
- }
64
- except Exception as e:
65
- raise HTTPException(status_code=500, detail=str(e))
66
-
67
- # Run the FastAPI app
68
- if __name__ == "__main__":
69
- import uvicorn
70
- uvicorn.run(app, host="0.0.0.0", port=7860)
71
-
72
-
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from sentence_transformers import SentenceTransformer
4
  import numpy as np
5
 
6
  # Initialize the FastAPI app
7
  app = FastAPI()
8
 
9
+ # Load the pre-trained SentenceTransformer model from Hugging Face
10
+ #model = SentenceTransformer("//huggingface.co/spaces/Kabila22/Kabilan_embedding_1", trust_remote_code=True)
11
+ model = SentenceTransformer("Alibaba-NLP/gte-base-en-v1.5", trust_remote_code=True)
12
 
13
+ # Define the request body schema
14
  class TextInput(BaseModel):
15
  text: str
16
 
 
 
 
 
17
  # Home route
18
  @app.get("/")
19
  async def home():
20
+ return {"message": "Welcome to embedding SMS API, use /docs to post SMS text and get dimensions"}
21
 
22
+ # Define the API endpoint
23
  @app.post("/embed")
24
  async def generate_embedding(text_input: TextInput):
25
  """
 
27
  Returns the embedding in a structured format with rounded values.
28
  """
29
  try:
30
+ # Generate the embedding
31
+ embedding = model.encode(text_input.text, convert_to_tensor=True).cpu().numpy()
32
+
33
+ # Round embedding values to 2 decimal places
34
  rounded_embedding = np.round(embedding, decimals=2).tolist()
35
+
36
+ # Get the number of dimensions
37
  dimensions = len(rounded_embedding)
38
 
39
+ # Return structured response
40
  return {
41
  "dimensions": dimensions,
42
+ "embeddings": [rounded_embedding] # Wrap the embedding inside a list
43
  }
44
  except Exception as e:
45
+ # Handle any errors
46
  raise HTTPException(status_code=500, detail=str(e))
47