Commit
·
f1017a3
1
Parent(s):
93479fa
test files added
Browse files- Huggin_face_test/fsa.py +304 -0
- Huggin_face_test/helpers.py +246 -0
Huggin_face_test/fsa.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Importing libraries
|
| 2 |
+
from threading import Thread
|
| 3 |
+
from flask import Blueprint, jsonify, request
|
| 4 |
+
from flask_cors import CORS
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Importing process pool executor
|
| 10 |
+
from concurrent.futures import ProcessPoolExecutor
|
| 11 |
+
|
| 12 |
+
# Fasttext for model handling
|
| 13 |
+
import fasttext
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Setting absolute path
|
| 17 |
+
sys.path.insert(0, os.path.abspath("."))
|
| 18 |
+
|
| 19 |
+
from app.config import Config
|
| 20 |
+
from app.helpers import *
|
| 21 |
+
from app.db.models import Tasks
|
| 22 |
+
from app.database import db
|
| 23 |
+
from app.threads.process_fsa_v2 import process_fsa_categories_v2
|
| 24 |
+
# from app.threads.process_fsa_v2 import test_function
|
| 25 |
+
|
| 26 |
+
# Create a Blueprint for FSA v2 classification; every route below is served
# under the /api/v2/fsa prefix.
fsa = Blueprint("fsa_v2", __name__, url_prefix="/api/v2/fsa")

# Enabling CORS for the blueprint (credentials allowed so browser clients can
# send cookies/auth headers).
CORS(
    fsa,
    supports_credentials=True
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Thread class to run the batch processing in a background thread
class FSAThread_V2(Thread):
    """Background thread that runs FSA batch categorisation.

    The payload dict is forwarded unchanged to
    ``process_fsa_categories_v2`` when the thread is started.
    """

    def __init__(self, data=None) -> None:
        # NOTE: the original signature used a mutable default (``data={}``),
        # which is shared across all instances; a ``None`` sentinel avoids
        # that while keeping the call-site behavior identical.
        Thread.__init__(self)
        self.data = {} if data is None else data

    # Run function of the thread
    def run(self) -> None:
        # Delegate the whole batch job to the worker function.
        process_fsa_categories_v2(self.data)
|
| 44 |
+
|
| 45 |
+
# Creating a process pool executor shared by the batch route below; batch
# jobs run in separate processes so the web workers stay responsive.
# Set maximum processes
max_processes = 4
process_executor = ProcessPoolExecutor(max_workers=max_processes)
|
| 49 |
+
|
| 50 |
+
# Update the database
def update_db(table_idx, remarks=None):
    """Update the task row identified by ``table_idx``.

    Passed into worker processes (see ``process_csv``) so they can report
    completion/remarks back to the tasks table.
    """
    # Local import avoids a circular import at module load time.
    from app.api import app

    # Worker code runs outside a request, so an app context is pushed
    # explicitly before touching the database.
    with app.app_context():
        # NOTE(review): presumably stamps status/remarks on the task row —
        # confirm against Tasks.update_by_id.
        Tasks.update_by_id(table_idx, remarks)
        db.session.close()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Prediction for single product
@fsa.route("/single-product", methods=["POST"])
def predict_categories():
    """Classify a single product name through the L0..L4 fastText hierarchy.

    Expects JSON ``{"product_name": ...}``. Each level's predicted label is
    prepended to the product name and fed into the next level's model.
    L0-L3 share a 0.95 confidence threshold; the L4 model and threshold
    depend on the predicted L0 category. Returns the thresholded labels,
    scores and remarks plus the raw (un-thresholded) predictions.

    Error responses: 422 for bad input, 500 for model-load or L0-L3
    prediction failures, 422 for L4 failures (kept from the original).
    """
    body = request.json

    # If there is no body in the request send error message
    if not body:
        return jsonify({"message": "Cannot decode JSON from the body"}), 422

    product_name = body.get("product_name")
    if not product_name:
        return jsonify({"message": "Product name is missing"}), 422

    # Normalise the product name the same way the training data was prepared.
    product_name = preprocess(product_name)

    Logger.info(message="Processing FSA categorical data for " + product_name)

    # level -> (raw_label, raw_accuracy, return_label, return_score, status).
    # The original repeated the load/predict/threshold sequence five times;
    # it is collapsed into a data-driven loop with identical behavior.
    results = {}
    prev_label = None

    for level in ("L0", "L1", "L2", "L3"):
        # Load this level's model (narrowed from a bare ``except:`` so
        # KeyboardInterrupt/SystemExit are no longer swallowed).
        try:
            model = fasttext.load_model(f"app/models/{level}/{level}_model.bin")
        except Exception:
            return jsonify({"message": f"Can't load the {level} model"}), 500

        # L0 sees the bare product name; deeper levels get the parent label
        # prepended as context.
        text = product_name if prev_label is None else prev_label + " " + product_name
        label, accuracy = get_label_and_accuracy(model, text)
        return_label, return_score, label_status = get_return_labels(label, accuracy, 0.95)
        print(level, label, accuracy)

        if not label:
            return jsonify({"message": f"Error predicting {level} Category"}), 500

        results[level] = (label, accuracy, return_label, return_score, label_status)
        prev_label = label

    # The L4 model path, display name and confidence threshold all depend on
    # the predicted L0 category.
    l4_models = {
        "administrative": ("app/models/L4/administrative/L4_Admin_model.bin", 0.75, "Administrative"),
        "beverage": ("app/models/L4/beverage/L4_beverage_model.bin", 0.66, "Beverage"),
        "food": ("app/models/L4/food/L4_food_model.bin", 0.85, "Food"),
        "operationals": ("app/models/L4/operationals/L4_operationals_model.bin", 0.8, "Operationals"),
    }
    L0_label = results["L0"][0]
    if L0_label not in l4_models:
        # Unknown L0 category — no L4 model to apply.
        return jsonify({"message": "Error prediction of L4 Category"}), 422

    model_path, l4_threshold, display_name = l4_models[L0_label]
    try:
        model = fasttext.load_model(model_path)
    except Exception:
        return jsonify({"message": f"Can't load the L4 ({display_name}) model"}), 500

    # L4 input is the L3 label plus the product name.
    L4_label, L4_accuracy = get_label_and_accuracy(model, results["L3"][0] + " " + product_name)
    L4_return_label, L4_return_score, L4_label_status = get_return_labels(
        L4_label, L4_accuracy, l4_threshold
    )
    print("L4", L4_label, L4_accuracy)

    if not L4_label:
        # NOTE: L4 failures return 422 while L0-L3 use 500 — kept as-is.
        return jsonify({"message": "Error predicting L4 Category"}), 422

    results["L4"] = (L4_label, L4_accuracy, L4_return_label, L4_return_score, L4_label_status)

    Logger.info(message="Done processing FSA categorical data for" + product_name)

    # Returning the result as JSON (dict insertion order preserves L0..L4).
    return jsonify({
        "classification_results": {level.lower(): res[2] for level, res in results.items()},
        "scores": {level.lower(): res[3] for level, res in results.items()},
        "remarks": {level.lower(): res[4] for level, res in results.items()},
        "all_classification_results": {level: res[0] for level, res in results.items()},
        "all_scores": {level: res[1] for level, res in results.items()},
    }), 200
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# Batch processing
@fsa.route("/process-csv", methods=["POST"])
def process_csv():
    """Queue asynchronous FSA categorisation of a CSV uploaded to S3.

    Expects JSON with ``uploaded_file_name`` (and optionally
    ``original_file_name``). Downloads the file from S3, validates that it
    has a ``product_name`` column, records a task row and submits the batch
    job to the process pool. Returns 200 immediately; the job runs in the
    background and reports completion via ``update_db``.
    """
    # Get the body of the json
    body = request.json

    # Error passing for missing body
    if not body:
        return jsonify({"message": "Cannot decode JSON from the body"}), 422

    # It is assumed that uploaded file name is in the file_name JSON field
    file_name = body.get("uploaded_file_name")

    # Original file name (falls back to the uploaded name when absent)
    original_file_name = body.get("original_file_name") or file_name

    # Missing file name
    if not file_name:
        return jsonify({"message": "File name is missing"}), 422

    files = [{"name": f"fsa_input_{file_name}", "path": f"FSA Categorization/input/{file_name}"}]

    # Download files from S3 bucket of AWS.
    # File is downloaded to 'app/constants/{file}'.
    for file in files:
        download_status = download_file_from_s3(
            file_name=file["name"], file_path=file["path"]
        )
        # download_file_from_s3 returns the ClientError object on failure
        # (error-as-value) rather than raising.
        if isinstance(download_status, botocore.exceptions.ClientError):
            return (
                jsonify({"message": f"Error downloading {file} from s3"}),
                422,
            )

    # Get the dataframe of the csv to check whether the product column exists.
    # NOTE(review): the download above saves as "fsa_input_{file_name}" but
    # this reads the bare "{file_name}" — confirm the file is present under
    # both names or whether this is a latent mismatch.
    df = read_files(file_name=file_name)

    # Check for product_name in columns
    if "product_name" not in df.columns:
        remove_files(f"fsa_input_{file_name}")
        return jsonify({"message": "Product name column is missing from the CSV"}), 422

    # Create a task row so progress/remarks can be tracked
    created_task = Tasks.create(file_name=file_name, original_file_name=original_file_name)

    # Create a json object of data to pass to the worker process
    data = {
        "file_name": file_name,
        "table_idx": created_task.id,
        "update_db": update_db
    }

    # Close the session before forking so the worker doesn't inherit it.
    db.session.close()
    # Add the process to the process pool executor
    result_future = process_executor.submit(process_fsa_categories_v2, (data))

    # Previous thread-based implementation, kept for reference:
    # thread = FSAThread_V2(data=data)
    # thread.start()

    return jsonify({"message": f"{file_name} - File processing starting"}), 200
|
Huggin_face_test/helpers.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import boto3
|
| 4 |
+
import botocore
|
| 5 |
+
import re
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from nltk.corpus import stopwords
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
warnings.filterwarnings("ignore")
|
| 11 |
+
|
| 12 |
+
from app.logger import Logger
|
| 13 |
+
|
| 14 |
+
sys.path.insert(0, os.path.abspath("."))
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def read_files(
    file_name, sort_by=None, drop_duplicates=None, drop_na=None, encoding=None
):
    """Load ``app/constants/<file_name>`` as a DataFrame with optional cleanup.

    Optionally sorts by a column, drops duplicate rows on a subset of
    columns, and drops rows with NaNs in a subset of columns. The returned
    frame always has a fresh 0..n-1 index.
    """
    frame = pd.read_csv(
        os.path.join("app/constants", file_name), low_memory=False, encoding=encoding
    )

    if sort_by:
        frame = frame.sort_values(by=[sort_by])

    if drop_duplicates:
        print("Removing duplicates in ProdName..")
        print("df rows before removing duplicates = " + str(frame.shape[0]))
        frame.drop_duplicates(subset=drop_duplicates, keep="first", inplace=True)
        print("df rows after removing duplicates = " + str(frame.shape[0]))

    if drop_na:
        print("Removing rows with null values..")
        print("df rows before removing nan values = " + str(frame.shape[0]))
        frame = frame.dropna(subset=drop_na)
        print("df rows after removing nan values = " + str(frame.shape[0]))

    return frame.reset_index(drop=True)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def check_file_already_downloaded(file_name):
    """Return True when *file_name* is already present in app/constants."""
    return file_name in os.listdir("app/constants")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def download_file_from_s3(
    file_name, bucket_name="sku-matching-ai-ml", skip_check=False, file_path=None
):
    """Fetch *file_name* from S3 into ``app/constants/``.

    Skips the download when the file is already present (unless
    ``skip_check``). Returns the file name on success; on failure logs and
    returns the ``ClientError`` object (error-as-value — callers check with
    ``isinstance``).
    """
    if check_file_already_downloaded(file_name) and not skip_check:
        return file_name

    print("STARTING DOWNLOADING: ", file_name)
    # The S3 key defaults to the local file name when no explicit path given.
    key = file_path if file_path else file_name
    client = boto3.client("s3")
    try:
        client.download_file(
            Bucket=bucket_name, Key=key, Filename=f"app/constants/{file_name}"
        )
    # pylint: disable=invalid-name
    except botocore.exceptions.ClientError as e:
        Logger().exception(
            message=f"Unable to download file: {file_name}",
        )
        return e
    print("DOWNLOADING FINISHED")
    return file_name
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def upload_files_to_s3(file_path, upload_path, bucket_name="sku-matching-ai-ml"):
    """Upload the local file at *file_path* to ``upload_path`` in S3.

    Returns None on success; on failure logs and returns the ``ClientError``
    object (error-as-value, matching ``download_file_from_s3``).
    """
    print("STARTING UPLOADING")
    s3 = boto3.client("s3")
    try:
        s3.upload_file(file_path, bucket_name, upload_path)
    except botocore.exceptions.ClientError as e:
        # Fixed typo ("uplaod") and included the file path so failures are
        # actionable in the logs.
        Logger().exception(
            message=f"Unable to upload file: {file_path}",
        )
        return e
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def clean(string):
    """Normalise *string* for matching.

    Strips everything but letters, lower-cases, then keeps only words that
    are at least three characters long and are not English stopwords.
    """
    letters_only = re.sub("[^a-zA-Z]+", " ", string)
    tokens = letters_only.lower().split()
    # NOTE: the stopword set is rebuilt on every call — presumably cheap
    # enough here, but a module-level cache would avoid the repeated work.
    stop_set = set(stopwords.words("english"))
    kept = [tok for tok in tokens if tok not in stop_set and len(tok) >= 3]
    return " ".join(kept)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def close_open_brackets(input_str):
    """Repair unbalanced brackets in *input_str*.

    Unmatched closing brackets are stripped (every occurrence of that
    character), and any brackets still open at the end are closed by
    appending their counterparts in reverse order. Mismatched pairs are
    pushed back onto the stack unchanged.
    """
    openers = ["(", "[", "{"]
    closers = [")", "]", "}"]
    pending = []

    for ch in input_str:
        if ch in openers:
            pending.append(ch)
        elif ch in closers:
            if pending:
                top = pending.pop()
                # A mismatched pair is kept on the stack rather than repaired.
                if openers.index(top) != closers.index(ch):
                    pending.append(top)
                    pending.append(ch)
            else:
                # Closer with nothing open: drop that character everywhere.
                input_str = input_str.replace(ch, "")

    # Close whatever is still open, innermost first.
    while pending:
        top = pending.pop()
        input_str += closers[openers.index(top)]

    return input_str
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def iterative_filtering(
    df,
    product,
    column_name,
    skip_clean=False,
    consider_starts_with=True,
    regex=False,
    close_brackets=False,
):
    """Narrow *df* word by word using the tokens of *product*.

    Each token filters the current frame via ``df_filtering_by_word``; after
    a successful match the matched token (plus trailing space) is stripped
    from the column so the next token matches against the remainder.
    Returns the last non-empty filtered frame, reindexed — never an empty
    frame, since ``out_df`` only advances on non-empty results.
    """
    # Normalise the query: full clean (stopwords etc.) or just lower-case.
    if not skip_clean:
        product = clean(product)
    else:
        product = product.lower()
    words = product.split()
    new_df = df
    index = 0
    out_df = new_df

    while new_df.shape[0] > 0 and index < len(words):
        out_df = new_df
        new_df = df_filtering_by_word(
            new_df,
            words[index],
            column_name,
            consider_starts_with,
            regex,
            close_brackets,
        )
        if new_df.shape[0] > 0:
            out_df = new_df
            # NOTE(review): mutates the filtered frame in place (pandas may
            # raise SettingWithCopyWarning) and removes EVERY occurrence of
            # "<word> ", not just a leading one — confirm this is intended.
            new_df[column_name] = new_df[column_name].str.replace(words[index] + " ", "")
        index = index + 1
    out_df = out_df.reset_index(drop=True)
    return out_df
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def df_filtering_by_word(
    df, word, column_name, consider_starts_with=True, regex=False, close_brackets=False
):
    """Filter *df* to rows whose *column_name* matches *word*.

    With ``consider_starts_with`` the prefix match is tried first, falling
    back to substring containment. Otherwise either a word-boundary regex
    (``regex=True``) or plain containment is used. When nothing matches,
    the full frame is returned unchanged so callers never lose all rows.
    On any exception the call is retried with a cleaned word.
    """
    try:
        if close_brackets:
            # Balance brackets so a partial "(foo" doesn't break the match.
            word = close_open_brackets(word)

        if consider_starts_with:
            filtered_df = df[df[column_name].str.startswith(word)]
            if filtered_df.shape[0] == 0:
                filtered_df = df[df[column_name].str.contains(word)]
        else:
            if regex:
                filtered_df = df[
                    df[column_name].str.contains(rf"\b({word})\b", case=False)
                ]
            else:
                filtered_df = df[df[column_name].str.contains(word)]
        if filtered_df.shape[0] == 0:
            filtered_df = df

        return filtered_df
    except Exception:
        # BUG FIX: the original retry omitted ``column_name`` and passed the
        # cleaned word in its place, so the retry filtered on a nonexistent
        # column. The retry drops ``close_brackets`` (already applied above).
        return df_filtering_by_word(
            df, clean(word), column_name, consider_starts_with, regex
        )
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def remove_files(file_name):
    """Delete ``app/constants/<file_name>`` if present; no-op otherwise."""
    target = f"app/constants/{file_name}"
    if os.path.exists(target):
        os.remove(target)
|
| 185 |
+
|
| 186 |
+
def get_top_mrf_product(mrf_product_attributes_list, dp_product_attributes, sequence_scores, default_attr_key_list):
    """Pick the MRF candidate that best matches the DP product attributes.

    Each candidate starts from its sequence-similarity score and earns +5
    for every attribute in *default_attr_key_list* that is present and
    non-null on both sides and equal case-insensitively. Returns the index
    of the first highest-scoring candidate and its score.
    """
    scores = []
    for pos, candidate in enumerate(mrf_product_attributes_list):
        total = sequence_scores[pos]
        for attr in default_attr_key_list:
            if attr not in dp_product_attributes or attr not in candidate:
                continue
            left = dp_product_attributes[attr]
            right = candidate[attr]
            if pd.notna(left) and pd.notna(right) and str(left).lower() == str(right).lower():
                total += 5
        scores.append(total)

    best = max(scores)
    return scores.index(best), best
|
| 199 |
+
|
| 200 |
+
# Helper files required for FSA V2
|
| 201 |
+
# Preprocessing Function
|
| 202 |
+
'''
|
| 203 |
+
This Function is using for preprocessing the input product names
|
| 204 |
+
'''
|
| 205 |
+
def preprocess(text):
    """Normalise a product name for fastText input.

    Replaces ``&`` with "and", turns punctuation into spaces, collapses
    runs of spaces, then strips and lower-cases.
    """
    replaced = re.sub(r'&', 'and', text)
    no_punct = re.sub(r'[^\w\s]', ' ', replaced)
    collapsed = re.sub(' +', ' ', no_punct)
    return collapsed.strip().lower()
|
| 210 |
+
|
| 211 |
+
# Function to preprocess labels from the previous prediction
|
| 212 |
+
def label_processing(label):
    """Turn a raw fastText label into a readable category name.

    Strips the ``__label__`` prefix, converts underscores to spaces,
    collapses repeated spaces, and returns the result stripped and
    lower-cased.
    """
    cleaned = label.replace('__label__', '')
    cleaned = cleaned.replace('_', ' ')
    cleaned = re.sub(' +', ' ', cleaned)
    return cleaned.strip().lower()
|
| 217 |
+
|
| 218 |
+
def get_return_labels(label, accuracy, threshold):
    """Gate a prediction by a confidence threshold.

    Returns ``(label, accuracy, status)`` when ``accuracy`` meets the
    threshold, otherwise ``(None, None, status)`` with an "Unclassified"
    status message.
    """
    if accuracy >= threshold:
        return label, accuracy, f"Classified - Above threshold {threshold}"
    # Fixed typo in the original status string ("Unclassfied").
    return None, None, f"Unclassified - Below threshold {threshold}"
|
| 228 |
+
|
| 229 |
+
#Function to get the product label and accuracy
|
| 230 |
+
#Function to get the top predicted label and its rounded confidence
def get_label_and_accuracy(model, product_name):
    """Run *model* on *product_name* and return ``(label, accuracy)``.

    The raw fastText label is normalised via ``label_processing`` and the
    confidence is rounded to three decimal places.
    """
    result = model.predict(product_name)
    top_label = label_processing(result[0][0])
    top_score = round(result[1][0], 3)
    return top_label, top_score
|
| 237 |
+
|
| 238 |
+
# Function for remove new line in product name
|
| 239 |
+
'''
|
| 240 |
+
Some products may contain new line characters in middle of product names.
|
| 241 |
+
This may occur because of preprocessing. It can lead to result \n in middle of the
|
| 242 |
+
product names.
|
| 243 |
+
'''
|
| 244 |
+
def remove_new_lines(text):
    """Replace embedded newlines with spaces; return stripped, lower-cased."""
    flattened = text.replace('\n', ' ')
    return flattened.strip().lower()
|