retailpxdemo / app.py
lemmyfly's picture
[Feat] 2024 demo
2e83d6d
raw
history blame
3.76 kB
import os
import gradio as gr
from PIL import Image
from utils import run_ocr, image_to_byte_array, resize_to_megapixels, load_product_json_data, \
load_product_dataframe_data
import pandas as pd
import joblib
from dotenv import load_dotenv
# Load variables from a local .env file (if one exists) into the process
# environment before anything below reads them.
load_dotenv()

# Classifier and its text vectorizer, deserialized with joblib (pickle-based:
# only ever load trusted files here). Paths are relative to the current
# working directory.
model = joblib.load('model/gs1_demo_classifier.pkl')
vectorizer = joblib.load('model/gs1_demo_vectorizer.pkl')

# Product reference data; extract_all looks entries up by predicted label.
product_info_json = load_product_json_data('product_data/')
product_info_df = load_product_dataframe_data('product_data/generated/')

# NOTE(review): read from the environment but not used by the code in this
# file; presumably intended for app authentication — confirm.
password = os.environ.get('password')
def process_image_ocr(image_array, megapixels=5):
    """Run OCR on one uploaded image and return its text as a single line.

    Args:
        image_array: Pixel array for one image (as produced by a Gradio
            ``"image"`` input), or ``None`` when that input was left empty.
        megapixels: Target size handed to ``resize_to_megapixels`` before the
            image is sent to OCR. Defaults to 5, the previously fixed value.

    Returns:
        The OCR ``full_text`` with every newline replaced by a space, or an
        empty string when no image was supplied.
    """
    # An empty Gradio image input arrives as None; Image.fromarray(None) would
    # raise, so treat a missing image as "no text" instead of failing the request.
    if image_array is None:
        return ""
    image = Image.fromarray(image_array)
    image_resized = resize_to_megapixels(image, megapixels)
    image_byte_array = image_to_byte_array(image_resized)
    # run_ocr also returns the raw response payload, which is not needed here.
    ocr_text_out, _response_json = run_ocr(image_byte_array)
    return str(ocr_text_out.full_text).replace('\n', ' ')
def _predict_product(ocr_text):
    """Return the classifier's product label for a single OCR text string."""
    # transform() takes a batch of documents and yields one feature row per
    # document; predict() accepts that 2-D batch directly.
    features = vectorizer.transform([ocr_text])
    return model.predict(features)[0]


def _attributes_from_json(prediction, product_entry):
    """Build (attribute table, nutrient table) from a JSON product entry.

    The entry is copied before 'NutrientTable' is removed so the shared
    product_info_json cache is never mutated between requests.
    """
    attributes = dict(product_entry)
    nutrient_table = attributes.pop('NutrientTable', {})
    rows = [['product', prediction]]
    rows.extend([name, value] for name, value in attributes.items())
    output_dataframe = pd.DataFrame(rows, columns=['Attribute Name', 'Attribute Value'])
    output_dataframe = output_dataframe.sort_values(by='Attribute Name')
    return output_dataframe, pd.DataFrame(nutrient_table)


def _attributes_from_dataframe(prediction):
    """Build (attribute table, nutrient table) from the per-product DataFrame data.

    An unknown prediction yields empty tables instead of raising.
    """
    product_frame = product_info_df.get(prediction)
    if product_frame is None:
        # Empty fallback must carry the columns read below ('attribute' for
        # sorting/filtering, 'text' for the nutrient payload).
        product_frame = pd.DataFrame(columns=['attribute', 'text'])
    # Not inplace: the cached DataFrame must stay untouched.
    product_frame = product_frame.sort_values(by='attribute')
    is_nutrient_row = product_frame['attribute'] == 'NutrientTable'
    nutrient_rows = product_frame[is_nutrient_row]
    nutrient_table = {} if nutrient_rows.empty else nutrient_rows.iloc[0]['text']
    return product_frame[~is_nutrient_row], pd.DataFrame(nutrient_table)


def extract_all(image_front_array, image_back_array):
    """Identify a product from its front and back images and return its attributes.

    Both images are OCR'd, the two texts are joined with a space, and the
    classifier predicts a product label. Attributes come from the JSON product
    data when it has a non-empty entry for that label; otherwise they come from
    the per-product DataFrame data.

    Args:
        image_front_array: Front image array, passed to process_image_ocr.
        image_back_array: Back image array, passed to process_image_ocr.

    Returns:
        A tuple ``(attributes, nutrient_table, ocr_text)``:
        - attributes: DataFrame. From JSON data it has columns
          'Attribute Name'/'Attribute Value', includes a 'product' row with the
          prediction, and is sorted by name. From DataFrame data it is the
          stored frame sorted by 'attribute' without its 'NutrientTable' row.
        - nutrient_table: DataFrame built from the 'NutrientTable' value
          (empty when the product has none).
        - ocr_text: the combined OCR text fed to the classifier.
    """
    ocr_front = process_image_ocr(image_front_array)
    ocr_back = process_image_ocr(image_back_array)
    ocr_combined = f"{ocr_front} {ocr_back}"

    prediction = _predict_product(ocr_combined)

    product_entry = product_info_json.get(prediction)
    if product_entry:
        output_dataframe, nutrient_table_df = _attributes_from_json(prediction, product_entry)
    else:
        output_dataframe, nutrient_table_df = _attributes_from_dataframe(prediction)
    return output_dataframe, nutrient_table_df, ocr_combined
# Output components for the Interface below.
# NOTE(review): attributes_tbox is created but never passed to the Interface.
attributes_tbox = gr.Textbox(label='Attributes')
ocr_output_tbox = gr.Textbox(label='OCR Output')
nutrient_table_output = gr.Dataframe(label='Nutrient Table', headers=["", ""])

# Page CSS: page background colour, primary/secondary button colours, the <h1>
# title replaced by the pup-logo.svg image (its text pushed off-screen via
# text-indent), and centred 24px paragraphs inside .output-markdown elements.
app_css = (
    "body {background-color: #F5F7FA} "
    ".gr-button.gr-button-primary {background-color: #0080fa; color:white; --tw-gradient-from:0} "
    ".gr-button.gr-button-secondary {background-color: #172533; color:white; --tw-gradient-from:0}"
    " h1 {background-image: url('file=pup-logo.svg'); "
    "background-size:contain; background-repeat:no-repeat; background-position:center; "
    "text-indent:-999999999px} .output-markdown "
    "p{text-align:center; font-size:24px}"
)

# The two image inputs map onto extract_all(image_front_array, image_back_array);
# the three outputs match its (attributes, nutrient table, OCR text) return order.
demo = gr.Interface(
    fn=extract_all,
    inputs=["image", "image"],
    outputs=[
        gr.Dataframe(headers=["Attribute Name", "Attribute Value"], label='Extracted Attributes'),
        nutrient_table_output,
        ocr_output_tbox,
    ],
    title='World of Content',
    description="GS1 Attribute Extractor",
    css=app_css,
)
demo.launch()