# Author: HCIlab-SMWU
# Last change: "Update app.py" (commit 39d47ff, verified)
import os
import re

import gradio as gr
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from sklearn.cluster import KMeans
# Hugging Face token used to read the (possibly private) dataset repository.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("ํ† ํฐ์ด ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. Repository Secrets์— HF_TOKEN์ด ์žˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”.")
else:
    print("ํ† ํฐ์ด ์„ค์ •๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ํŒŒ์ผ์„ ๋‹ค์šด๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.")

# Download the feed table from the dataset repo. The download is attempted
# even without a token, so a missing/invalid token fails loudly here at
# startup instead of leaving `df` undefined for the request handlers.
file_path = hf_hub_download(
    repo_id="HCIlab-SMWU/final_petfood_dataset",
    filename="new_data_with_features.xlsx",
    repo_type="dataset",
    # `token` replaces the deprecated `use_auth_token` argument; None falls
    # back to huggingface_hub's default token lookup.
    token=hf_token or None,
)
print(f"ํŒŒ์ผ์ด ๋‹ค์šด๋กœ๋“œ๋˜์—ˆ์Šต๋‹ˆ๋‹ค: {file_path}")

# Load the Excel sheet into the DataFrame that `filter_feed` queries.
df = pd.read_excel(file_path)
print(df.head())
row_count = df.shape[0]
print(f"DataFrame์˜ ํ–‰ ๊ฐœ์ˆ˜๋Š”: {row_count}")

# Names of the feature (key1) and ingredient (key2) columns, kept in Space
# secrets rather than in the source.
key1 = os.getenv("secret1")
key2 = os.getenv("secret2")
if key1 is None or key2 is None:
    print("Warning: secret1/secret2 are not set; feed filtering will fail until they are configured.")
# ํ•„ํ„ฐ๋ง ํ•จ์ˆ˜
def filter_feed(age_input, allergies_input, health_concerns_input, pet_type_input, sort_option):
data = df.copy()
error_message = "" # ์—๋Ÿฌ ๋ฉ”์‹œ์ง€ ์ดˆ๊ธฐํ™”
# ๋ฐ˜๋ ค๋™๋ฌผ ์ข…๋ฅ˜ ํ•„ํ„ฐ๋ง
if pet_type_input:
data = data[data['์ข…'] == pet_type_input]
# ์—ฐ๋ น ํ•„ํ„ฐ๋ง
if age_input == "1์‚ด ๋ฏธ๋งŒ":
age_filter = ["ํผํ”ผ", "ํ‚คํŠผ", "์ „์—ฐ๋ น"]
elif age_input == "1์‚ด ์ด์ƒ, 7์‚ด ์ดํ•˜":
age_filter = ["์–ด๋œํŠธ", "์ „์—ฐ๋ น"]
else:
age_filter = ["์‹œ๋‹ˆ์–ด", "์ „์—ฐ๋ น"]
data = data[data['๊ธ‰์—ฌ๋Œ€์ƒ'].str.contains('|'.join(age_filter), na=False)]
if allergies_input:
allergy_pattern = '|'.join(allergies_input)
data = data[~data[key2].str.contains(allergy_pattern, na=False)]
data = data[data[key2].notna() & (data[key2] != "")]
# ๊ฑด๊ฐ• ๊ณ ๋ฏผ ํ•„ํ„ฐ๋ง
health_mapping = {
"์น˜์•„/๊ตฌ๊ฐ•": ["์น˜์„์ œ๊ฑฐ", "๊ตฌ๊ฐ•๊ด€๋ฆฌ"],
"๋ผˆ/๊ด€์ ˆ": ["๊ด€์ ˆ๊ฐ•ํ™”"],
"ํ”ผ๋ถ€/๋ชจ์งˆ": ["ํ”ผ๋ชจ๊ด€๋ฆฌ"],
"์•Œ๋Ÿฌ์ง€": ["์ €์•Œ๋Ÿฌ์ง€"],
"๋น„๋งŒ": ["๋‹ค์ด์–ดํŠธ/์ค‘์„ฑํ™”", "์ฒด์ค‘์œ ์ง€"],
"๋น„๋‡จ๊ธฐ": ["์œ ๋ฆฌ๋„ˆ๋ฆฌ(๋น„๋‡จ๊ณ„)", "๊ฒฐ์„์˜ˆ๋ฐฉ", "์‹ ์žฅ/์š”๋กœ", "์Œ์ˆ˜๋Ÿ‰์ฆ์ง„"],
"๋ˆˆ": ["๋ˆˆ๊ฑด๊ฐ•"],
"์†Œํ™”๊ธฐ": ["์†Œํ™”๊ฐœ์„ "],
"ํ–‰๋™": ["๋ถ„๋ฆฌ๋ถˆ์•ˆํ•ด์†Œ", "์ŠคํŠธ๋ ˆ์Šค์™„ํ™”"],
"์‹ฌ์žฅ": ["์‹ฌ์žฅ๊ฑด๊ฐ•"],
"ํ˜ธํก๊ธฐ": ["ํ˜ธํก๊ธฐ๊ด€๋ฆฌ"],
"๋…ธํ™”": ["ํ•ญ์‚ฐํ™”"],
"ํ—ค์–ด๋ณผ": ["ํ—ค์–ด๋ณผ"]
}
health_patterns = []
# ๊ฑด๊ฐ• ๊ณ ๋ฏผ์ด 3๊ฐœ๋ฅผ ์ดˆ๊ณผํ•  ๊ฒฝ์šฐ
if len(health_concerns_input) > 3:
error_message = "๊ฑด๊ฐ• ๊ณ ๋ฏผ์€ ์ตœ๋Œ€ 3๊ฐœ๊นŒ์ง€ ์„ ํƒ ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค."
return pd.DataFrame(), error_message # ๋นˆ ๋ฐ์ดํ„ฐํ”„๋ ˆ์ž„๊ณผ ์—๋Ÿฌ ๋ฉ”์‹œ์ง€ ๋ฐ˜ํ™˜
for concern in health_concerns_input:
health_patterns.append(health_mapping.get(concern, []))
# ๋ชจ๋“  ๊ฑด๊ฐ• ๊ณ ๋ฏผ ์กฐ๊ฑด์„ ๋งŒ์กฑํ•˜๋Š” ๋ฐ์ดํ„ฐ ํ•„ํ„ฐ๋ง
filtered_data = data.copy()
for patterns in health_patterns:
if patterns: # ๋นˆ ํŒจํ„ด์ด ์•„๋‹ ๊ฒฝ์šฐ
filtered_data = filtered_data[filtered_data[key1].str.contains('|'.join(patterns), na=False)]
# ๊ฒฐ๊ณผ์— ์ข…, ๊ธ‰์—ฌ๋Œ€์ƒ, ๊ธฐ๋Šฅ, ์ฃผ์›๋ฃŒ ์ถ”๊ฐ€
results = filtered_data[['Cleaned_Product_Name', '์ข…', '๊ธ‰์—ฌ๋Œ€์ƒ', '๊ธฐ๋Šฅ', '์ฃผ์›๋ฃŒ']]
# KMeans๋ฅผ ์ ์šฉํ•˜์—ฌ ํด๋Ÿฌ์Šคํ„ฐ๋ง
if len(results) > 10:
# ์›-ํ•ซ ์ธ์ฝ”๋”ฉ ์ˆ˜ํ–‰
one_hot_ingredients = pd.get_dummies(results[key2].str.split('|').explode()).groupby(level=0).max()
one_hot_features = pd.get_dummies(results[key1].str.split('|').explode()).groupby(level=0).max()
# ์›-ํ•ซ ์ธ์ฝ”๋”ฉ๋œ ๋ฐ์ดํ„ฐ์™€ ๊ธฐ์กด ๋ฐ์ดํ„ฐ๋ฅผ ๊ฒฐํ•ฉ
features = pd.concat([one_hot_ingredients, one_hot_features], axis=1)
# KMeans ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํ•™์Šต
kmeans = KMeans(n_clusters=10, random_state=0)
kmeans.fit(features)
# ๊ฐ ํด๋Ÿฌ์Šคํ„ฐ์˜ ์ค‘์‹ฌ์ ์— ๋Œ€ํ•œ ๊ฑฐ๋ฆฌ ๊ณ„์‚ฐ
distances = kmeans.transform(features)
# ๊ฐ ๋ฐ์ดํ„ฐ ํฌ์ธํŠธ๊ฐ€ ์†ํ•œ ํด๋Ÿฌ์Šคํ„ฐ์™€ ๊ทธ ์ค‘์‹ฌ์ ๊ณผ์˜ ๊ฑฐ๋ฆฌ ์ถ”๊ฐ€
results['cluster'] = kmeans.labels_
results['distance_to_centroid'] = [distances[i][label] for i, label in enumerate(kmeans.labels_)]
# ๊ฐ ํด๋Ÿฌ์Šคํ„ฐ์—์„œ ์ค‘์‹ฌ์ ์— ๊ฐ€์žฅ ๊ฐ€๊นŒ์šด ๋ฐ์ดํ„ฐ๋งŒ ์„ ํƒ
closest_to_centroid = results.loc[results.groupby('cluster')['distance_to_centroid'].idxmin()]
# ์ •๋ ฌ ์˜ต์…˜์— ๋”ฐ๋ผ ์ •๋ ฌ ์ˆ˜ํ–‰
if sort_option == "๊ฐ€๋‚˜๋‹ค์ˆœ":
closest_to_centroid = closest_to_centroid.sort_values(by=['Cleaned_Product_Name']).reset_index(drop=True)
results = closest_to_centroid[['Cleaned_Product_Name', '์ข…', '๊ธ‰์—ฌ๋Œ€์ƒ', '๊ธฐ๋Šฅ', '์ฃผ์›๋ฃŒ']]
# 'Cleaned_Product_Name' ์—ด ์ด๋ฆ„์„ '์‚ฌ๋ฃŒ์ด๋ฆ„'์œผ๋กœ ๋ณ€๊ฒฝ
results = results.rename(columns={'Cleaned_Product_Name': '์‚ฌ๋ฃŒ์ด๋ฆ„'})
# ์—๋Ÿฌ ๋ฉ”์‹œ์ง€ ์ฒ˜๋ฆฌ
if results.empty:
error_message = "์•Œ๋Ÿฌ์ง€์™€ ๊ฑด๊ฐ•๊ณ ๋ฏผ ์กฐ๊ฑด์„ ๋ชจ๋‘ ๋งŒ์กฑํ•˜๋Š” ์‚ฌ๋ฃŒ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค."
return results, error_message
def check_health_concerns(health_concerns_input):
    """Return the over-limit warning when more than 3 concerns are ticked, else ""."""
    # The checkbox group may report None before any interaction.
    if len(health_concerns_input or []) > 3:
        return "๊ฑด๊ฐ• ๊ณ ๋ฏผ์€ ์ตœ๋Œ€ 3๊ฐœ๊นŒ์ง€ ์„ ํƒ ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค."
    return ""


# Gradio interface: input widgets, result table, error box, and event wiring.
with gr.Blocks() as demo:
    gr.Markdown("# ๋ฐ˜๋ ค๋™๋ฌผ ์‚ฌ๋ฃŒ ์ถ”์ฒœ")
    age_input = gr.Dropdown(["1์‚ด ๋ฏธ๋งŒ", "1์‚ด ์ด์ƒ, 7์‚ด ์ดํ•˜", "7์‚ด ์ด์ƒ"], label="์—ฐ๋ น๋Œ€")
    allergies_input = gr.CheckboxGroup(
        ["์†Œ", "๋ผ์ง€", "๋‹ญ", "์˜ค๋ฆฌ", "์–‘", "์น ๋ฉด์กฐ", "์ƒ์„ /ํ•ด์‚ฐ๋ฌผ", "์‚ฌ์Šด", "์—ฐ์–ด", "์น˜์ฆˆ/์œ ์ง€๋ฐฉ", "์ฐธ์น˜", "๋ฐ€", "์Œ€","๊ณ ๊ตฌ๋งˆ","๊ณก๋ฌผ","๊ณค์ถฉ","๊ณผ์ผ/์•ผ์ฑ„","๋ถ์–ด","์ฒญ์–ด"],
        label="์ฃผ์›๋ฃŒ ์•Œ๋Ÿฌ์ง€",
    )
    health_concerns_input = gr.CheckboxGroup(
        ["์น˜์•„/๊ตฌ๊ฐ•", "๋ผˆ/๊ด€์ ˆ", "ํ”ผ๋ถ€/๋ชจ์งˆ", "์•Œ๋Ÿฌ์ง€", "๋น„๋งŒ", "๋น„๋‡จ๊ธฐ", "๋ˆˆ", "์†Œํ™”๊ธฐ", "ํ–‰๋™", "์‹ฌ์žฅ", "ํ˜ธํก๊ธฐ", "๋…ธํ™”", "ํ—ค์–ด๋ณผ"],
        label="๊ฑด๊ฐ• ๊ณ ๋ฏผ (์ตœ๋Œ€ 3๊ฐœ ์„ ํƒ ๊ฐ€๋Šฅ)",
    )
    pet_type_input = gr.Dropdown(["๊ฐ•์•„์ง€", "๊ณ ์–‘์ด"], label="๋ฐ˜๋ ค๋™๋ฌผ ์ข…๋ฅ˜")
    sort_option = gr.Radio(["์ถ”์ฒœ์ˆœ", "๊ฐ€๋‚˜๋‹ค์ˆœ"], label="์ •๋ ฌ ๋ฐฉ์‹", value="์ถ”์ฒœ์ˆœ")
    submit_button = gr.Button("์ถ”์ฒœ ์‚ฌ๋ฃŒ ๋ณด๊ธฐ")
    output = gr.Dataframe()
    error_output = gr.Textbox(label="์—๋Ÿฌ ๋ฉ”์‹œ์ง€", interactive=False)

    # Both triggers run the same filter with the same inputs/outputs.
    filter_inputs = [age_input, allergies_input, health_concerns_input, pet_type_input, sort_option]
    filter_outputs = [output, error_output]
    # Re-sort immediately when the sort option changes, and on explicit submit.
    sort_option.change(fn=filter_feed, inputs=filter_inputs, outputs=filter_outputs)
    submit_button.click(fn=filter_feed, inputs=filter_inputs, outputs=filter_outputs)
    # Immediate feedback when too many health concerns are ticked.
    health_concerns_input.change(fn=check_health_concerns, inputs=health_concerns_input, outputs=error_output)

if __name__ == "__main__":
    demo.launch(share=True)