# File: Embedded/back_end/service/train_model.py
# Last change by velmurugan1122: "cousin method" (ad471a0)
import os
import pickle
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# --- Locations -------------------------------------------------------------
# Directory that contains this script.
# NOTE(review): BASE_DIR is not referenced anywhere below. DATA_PATH and
# MODEL_PATH are resolved against the current working directory, so the
# script must be launched from the folder that holds data/ and models/.
# Join them onto BASE_DIR if script-relative lookup was intended.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_PATH = "data/sms_process_data_main.xlsx"  # input spreadsheet
MODEL_PATH = "models/logistic.pkl"  # output location of the pickled classifier

# --- Pre-flight check ------------------------------------------------------
# Echo the dataset location, then stop immediately with a descriptive error.
# This avoids failing later inside pandas.
print(DATA_PATH)
dataset_found = os.path.exists(DATA_PATH)
if not dataset_found:
    raise FileNotFoundError(f"Dataset file not found at: {DATA_PATH}")
# Load the labelled dataset from the Excel workbook.
df = pd.read_excel(DATA_PATH)

# Ensure the dataset has the required columns (adjust as necessary).
if not {'text', 'label'}.issubset(df.columns):
    raise ValueError("Dataset must contain 'text' and 'label' columns")

# Drop rows whose message or label is blank. A NaN text would reach the
# sentence encoder as a float rather than a string, and LogisticRegression.fit
# rejects NaN targets. Either way training would abort much later with an
# unrelated-looking error.
_rows_before = len(df)
df = df.dropna(subset=['text', 'label'])
if len(df) < _rows_before:
    print(f"Dropped {_rows_before - len(df)} row(s) with missing 'text' or 'label'")
if df.empty:
    raise ValueError("Dataset contains no rows with both 'text' and 'label'")
# Load the sentence-embedding model. trust_remote_code=True allows the model
# repository to supply its own modeling code, which is then executed locally.
# Only keep it enabled for a repository you trust.
embedding_model = SentenceTransformer('Alibaba-NLP/gte-base-en-v1.5', trust_remote_code=True)

# Generate embeddings with one batched encode() call over every message.
# Calling encode() once per row (as Series.apply did) runs a batch-of-one
# forward pass per message and defeats the library's internal batching.
# `.tolist()` keeps X as a list of per-message float lists, the same
# structure as before.
X = embedding_model.encode(df['text'].tolist()).tolist()
y = df['label']
# Hold out 20% of the rows for evaluation. The split is fixed by a seed so
# runs are reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    random_state=42,
    test_size=0.2,
)

# Fit a logistic-regression classifier on the embedding vectors. The raised
# iteration cap gives the solver room to converge. fit() returns the
# estimator itself, so construction and training can be chained.
logistic_model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
# Score the held-out split. Predict a label for each test embedding, then
# report the fraction of predictions that match the true labels.
y_pred = logistic_model.predict(X_test)
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
print(f"Model Accuracy: {accuracy:.4f}")
# Persist the trained classifier. Only the LogisticRegression object is
# written; the SentenceTransformer is not saved. Whatever loads this file must
# therefore embed text with the same 'Alibaba-NLP/gte-base-en-v1.5' model
# before calling predict().
print("Saving logistic model...")
# Create the output directory (models/ for the default path) if it does not
# exist yet; open() would otherwise raise FileNotFoundError. `or "."` covers a
# MODEL_PATH with no directory component.
os.makedirs(os.path.dirname(MODEL_PATH) or ".", exist_ok=True)
with open(MODEL_PATH, 'wb') as f:
    pickle.dump(logistic_model, f)
print(f"Logistic model saved to {MODEL_PATH}")