logeswari commited on
Commit
8425d1d
·
1 Parent(s): a1738a7
Files changed (3) hide show
  1. Dockerfile +22 -0
  2. app.py +83 -0
  3. requirements.txt +13 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ WORKDIR /app
4
+ COPY . /app
5
+
6
+ ENV HF_HOME=/app/.cache
7
+
8
+ RUN mkdir -p /app/.cache/huggingface/hub && \
9
+ chmod -R 777 /app/.cache && \
10
+ chmod -R 777 /app/.cache/huggingface
11
+
12
+
13
+
14
+ RUN pip install --upgrade pip
15
+ RUN pip install --no-cache-dir -r requirements.txt
16
+
17
+ COPY --chown=user ./requirements.txt requirements.txt
18
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
19
+
20
+ EXPOSE 7860
21
+
22
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.responses import FileResponse
3
+ import pandas as pd
4
+ from sklearn.model_selection import train_test_split
5
+ from sentence_transformers import SentenceTransformer
6
+ from sklearn.linear_model import LogisticRegression
7
+ from sklearn.metrics import accuracy_score
8
+ from pydantic import BaseModel
9
+ import numpy as np
10
+ import uvicorn
11
+ import logging
12
+
13
+ # Set up logging
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ app = FastAPI()
18
+
19
+ # Load and preprocess dataset
20
+ file_name = r"D:/new/sms_process_data_main.xlsx"
21
+ sheet = "Sheet1"
22
+ df = pd.read_excel(file_name, sheet_name=sheet)
23
+
24
+ # Split data
25
+ X_train, X_test, y_train, y_test = train_test_split(
26
+ df['MessageText'], df['label'], test_size=0.2, random_state=42
27
+ )
28
+
29
+ # Load sentence embedding model
30
+ embedding_model = SentenceTransformer('Alibaba-NLP/gte-base-en-v1.5', trust_remote_code=True)
31
+
32
+ # Generate embeddings
33
+ X_train_embeddings = embedding_model.encode(X_train.tolist(), convert_to_tensor=True).cpu().numpy()
34
+ X_test_embeddings = embedding_model.encode(X_test.tolist(), convert_to_tensor=True).cpu().numpy()
35
+
36
+ # Train logistic regression model
37
+ logistic_model = LogisticRegression(max_iter=1000)
38
+ logistic_model.fit(X_train_embeddings, y_train)
39
+
40
+ # Evaluate model
41
+ y_pred = logistic_model.predict(X_test_embeddings)
42
+ accuracy = accuracy_score(y_test, y_pred)
43
+ logger.info(f"Model trained with accuracy: {accuracy:.4f}")
44
+
45
+ # API Input Model
46
+ class MessageInput(BaseModel):
47
+ messages: list[str]
48
+
49
+ # Root endpoint
50
+ @app.get("/")
51
+ def read_root():
52
+ return {"message": "Welcome to the SMS Classification API!"}
53
+
54
+ # Predict endpoint
55
+ @app.post("/predict")
56
+ def predict_sms(data: MessageInput):
57
+ try:
58
+ # Generate embeddings for new messages
59
+ new_embeddings = embedding_model.encode(data.messages, convert_to_tensor=True).cpu().numpy()
60
+
61
+ # Predict labels
62
+ predictions = logistic_model.predict(new_embeddings).tolist()
63
+
64
+
65
+ # Prepare the response with embeddings and dimensions
66
+ response = {
67
+ "dimensions": new_embeddings.shape[1], # Number of dimensions in the embeddings
68
+ "embeddings": new_embeddings.tolist(), # Convert embeddings to a list
69
+ "predictions": predictions # Include predictions
70
+ }
71
+ return response
72
+
73
+ except Exception as e:
74
+ logger.error(f"Error during prediction: {e}")
75
+ raise HTTPException(status_code=500, detail=str(e))
76
+
77
+ # Favicon endpoint (optional)
78
+ @app.get("/favicon.ico")
79
+ def favicon():
80
+ return FileResponse("path/to/favicon.ico")
81
+
82
+ if __name__ == "__main__":
83
+ uvicorn.run(app, host="0.0.0.0", port=8000)
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ annotated-types==0.7.0
2
+ anyio==4.8.0
3
+ fastapi==0.115.8
4
+ idna==3.10
5
+ pydantic==2.10.6
6
+ pydantic_core==2.27.2
7
+ sniffio==1.3.1
8
+ starlette==0.45.3
9
+ typing_extensions==4.12.2
10
+ sentence-transformers==2.2.2
11
+ scikit-learn==1.3.2
12
+ numpy==1.26.4
13
+ pandas==2.1.4