LOG_REG / predict.py
subbunanepalli's picture
Create predict.py
e3c9101 verified
raw
history blame contribute delete
835 Bytes
import pandas as pd
from config import MODEL_PATH, TFIDF_PATH, TEXT_COLUMN, LABEL_COLUMNS
from utils import load_model_and_vectorizer
from schemas import TransactionData
def predict_labels(input_record: TransactionData) -> dict:
    """Predict one value per label column for a single input record.

    The record is turned into a one-row DataFrame, the ``TEXT_COLUMN``
    entry of that row is vectorized, and the model's prediction for it is
    returned keyed by label name.

    Args:
        input_record: The record to classify. It is expected to be a
            Pydantic model exposing ``model_dump()`` (v2) or ``dict()``
            (v1), and to contain the field(s) named by ``TEXT_COLUMN``.

    Returns:
        A dict mapping each entry of ``LABEL_COLUMNS`` to the element at
        the same position in the first row of ``model.predict``'s output.
        Values are passed through unchanged from the model output.
        NOTE(review): these may be NumPy scalars rather than builtin
        types -- confirm before JSON-serializing the result.
    """
    # NOTE(review): the model and vectorizer are obtained on every call.
    # If load_model_and_vectorizer reads from disk each time, caching it
    # at the call site may be worthwhile -- confirm against utils.
    model, vectorizer = load_model_and_vectorizer(MODEL_PATH, TFIDF_PATH)

    # Pydantic v2 deprecates ``.dict()`` in favour of ``.model_dump()``;
    # prefer the new name when present and fall back for v1 models.
    if hasattr(input_record, "model_dump"):
        record_fields = input_record.model_dump()
    else:
        record_fields = input_record.dict()

    # A one-row frame keeps column selection identical to how
    # TEXT_COLUMN would be used against a tabular dataset.
    input_data = pd.DataFrame([record_fields])
    sanction_context = input_data[TEXT_COLUMN].iloc[0]

    # The vectorizer expects an iterable of documents, hence the list.
    X_vec = vectorizer.transform([sanction_context])
    y_pred = model.predict(X_vec)

    # Only one record was submitted, so only the first output row matters.
    first_row = y_pred[0]
    return {label: first_row[i] for i, label in enumerate(LABEL_COLUMNS)}