molehh commited on
Commit
58dc04e
·
verified ·
1 Parent(s): c4559d7

changes in main

Browse files
Files changed (1) hide show
  1. main.py +106 -48
main.py CHANGED
@@ -1,48 +1,106 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- from sentence_transformers import SentenceTransformer
4
- import numpy as np
5
-
6
- # Initialize the FastAPI app
7
- app = FastAPI()
8
-
9
- # Load the pre-trained SentenceTransformer model
10
- model = SentenceTransformer("Alibaba-NLP/gte-base-en-v1.5", trust_remote_code=True)
11
-
12
- # Define the request body schema
13
- class TextInput(BaseModel):
14
- text: str
15
-
16
- # Home route
17
- @app.get("/")
18
- async def home():
19
- return {"message": "welcome to home page"}
20
-
21
- # Define the API endpoint for generating embeddings
22
- @app.post("/embed")
23
- async def generate_embedding(text_input: TextInput):
24
- """
25
- Generate a 768-dimensional embedding for the input text.
26
- Returns the embedding in a structured format with rounded values.
27
- """
28
- try:
29
- # Generate the embedding
30
- embedding = model.encode(text_input.text, convert_to_tensor=True).cpu().numpy()
31
-
32
- # Round embedding values to 2 decimal places
33
- rounded_embedding = np.round(embedding, 2).tolist()
34
-
35
- # Return structured response
36
- return {
37
- "dimensions": len(rounded_embedding),
38
- "embeddings": [rounded_embedding]
39
- }
40
-
41
- except Exception as e:
42
- # Handle any errors
43
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
44
-
45
- # Run the FastAPI app
46
- if __name__ == "__main__":
47
- import uvicorn
48
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from sentence_transformers import SentenceTransformer,util
4
+ from sklearn.model_selection import train_test_split
5
+ from sklearn.linear_model import LogisticRegression
6
+ import uvicorn
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+
11
+ # Initialize the FastAPI app
12
+ app = FastAPI()
13
+
14
+ # Load the pre-trained SentenceTransformer model
15
+ model = SentenceTransformer("Alibaba-NLP/gte-base-en-v1.5", trust_remote_code=True)
16
+
17
+ # Define the request body schema
18
+ class TextInput(BaseModel):
19
+ text: str
20
+
21
+ # Home route
22
+ @app.get("/")
23
+ async def home():
24
+ return {"message": "welcome to home page"}
25
+
26
+ # Define the API endpoint for generating embeddings
27
+ @app.post("/embed")
28
+ async def generate_embedding(text_input: TextInput):
29
+ """
30
+ Generate a 768-dimensional embedding for the input text.
31
+ Returns the embedding in a structured format with rounded values.
32
+ """
33
+ try:
34
+ # Generate the embedding
35
+ embedding = model.encode(text_input.text, convert_to_tensor=True).cpu().numpy()
36
+
37
+ # Round embedding values to 2 decimal places
38
+ rounded_embedding = np.round(embedding, 2).tolist()
39
+
40
+ # Return structured response
41
+ return {
42
+ "dimensions": len(rounded_embedding),
43
+ "embeddings": [rounded_embedding]
44
+ }
45
+
46
+ except Exception as e:
47
+ # Handle any errors
48
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
49
+
50
+
51
+ # Load pre-trained SentenceTransformer model
52
+ model = SentenceTransformer("Alibaba-NLP/gte-base-en-v1.5", trust_remote_code=True)
53
+
54
+ # Train the Logistic Regression model during app startup
55
+ df = pd.read_excel("sms_process_data_main.xlsx")
56
+ X_train, X_test, y_train, y_test = train_test_split(df["MessageText"], df["label"], test_size=0.2, random_state=42)
57
+ X_train_embeddings = model.encode(X_train.tolist())
58
+
59
+ # Initialize and train the Logistic Regression model
60
+ logreg_model = LogisticRegression(max_iter=100)
61
+ logreg_model.fit(X_train_embeddings, y_train)
62
+
63
+ # Define input schema
64
+ class TextInput(BaseModel):
65
+ text: str
66
+
67
+ @app.post("/predict")
68
+ async def generate_prediction(text_input: TextInput):
69
+ """
70
+ Predict the label for the given text input using the trained model.
71
+ """
72
+ try:
73
+ # Generate embedding for the input text
74
+ new_embedding = model.encode([text_input.text])
75
+
76
+ # Predict the label using the trained Logistic Regression model
77
+ prediction = logreg_model.predict(new_embedding).tolist()[0] # Extract single prediction
78
+
79
+ # Return structured response
80
+ return {
81
+
82
+ "predicted_label": prediction
83
+ }
84
+ except Exception as e:
85
+ # Handle any errors
86
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
87
+
88
+ class SentencesInput(BaseModel):
89
+ sentence1: str
90
+ sentence2: str
91
+ @app.post("/text_to_tensor")
92
+ def text_to_tensor(input: SentencesInput):
93
+ try:
94
+ # Generate embeddings
95
+ embeddings = model.encode([input.sentence1, input.sentence2])
96
+
97
+ # Compute cosine similarity
98
+ cosine_similarity = util.cos_sim(embeddings[0], embeddings[1]).item()
99
+
100
+ return {"cosine_similarity": round(cosine_similarity, 3)}
101
+ except Exception as e:
102
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
103
+
104
+
105
+ if __name__ == "__main__":
106
+ uvicorn.run(app, host="0.0.0.0", port=7860)