# seoulalpha/cluster_predictor.py — SyngyeonTak, "cluster_predictor updates" (commit 8afc022)
# cluster_predictor.py
import joblib
import pandas as pd
from openai import OpenAI
import os
import json
from huggingface_hub import hf_hub_download
# Download the prompt files from the Hugging Face dataset repo (cached locally
# by huggingface_hub; runs once at import time and requires network on a cold cache).
PROMPT_PATH = hf_hub_download(
    repo_id="Syngyeon/seoulalpha-data",
    repo_type="dataset",  # must be "dataset" — the default repo_type ("model") would 404
    filename="data/prompt/custom_prompt_eng.txt"
)
FEWSHOT_PATH = hf_hub_download(
    repo_id="Syngyeon/seoulalpha-data",
    repo_type="dataset",  # must be "dataset" — the default repo_type ("model") would 404
    filename="data/prompt/custom_few_shot_learning_multi_language.txt"
)
# --- Initial setup ---
# OpenAI client; the key is read from the API_KEY environment variable.
client = OpenAI(api_key=os.getenv("API_KEY"))

# User-facing Korean persona descriptions, keyed by the KMeans cluster label
# returned by the trained model (labels 0-6).
CLUSTER_PROFILES = {
    0: "๋ฌธํ™”, ์—ญ์‚ฌ, ์ž์—ฐ ํƒ๋ฐฉ์„ ์ฃผ๋ชฉ์ ์œผ๋กœ ๊ฐ€์„์— ํ•œ๊ตญ์„ ์žฌ๋ฐฉ๋ฌธํ•˜๋Š” ์—ฌํ–‰๊ฐ. ๊ธด ์ฒด๋ฅ˜ ๊ธฐ๊ฐ„ ๋™์•ˆ ์„œ์šธ๊ณผ ์—ฌ๋Ÿฌ ์ง€๋ฐฉ(๊ฒฝ๊ธฐ, ๊ฐ•์›, ๊ฒฝ์ƒ)์„ ํ•จ๊ป˜ ๋ฐฉ๋ฌธํ•˜๋ฉฐ, ๋งค์šฐ ์•Œ๋œฐํ•˜๊ฒŒ ์†Œ๋น„ํ•˜๋Š” ๊ฒฝํ–ฅ์ด ์žˆ์Œ.",
    1: "ํ•œ๊ตญ์„ ์ฒ˜์Œ ๋ฐฉ๋ฌธํ•œ ์—ฌํ–‰๊ฐ. ์งง์€ ๊ธฐ๊ฐ„ ๋™์•ˆ ์„œ์šธ์—๋งŒ ๋จธ๋ฌด๋ฅด๋ฉฐ ์Œ์‹๊ณผ ๋ฏธ์‹ ํƒ๋ฐฉ์— ๊ฐ€์žฅ ํฐ ๊ด€์‹ฌ์„ ๋‘๊ณ  ์—ฌํ–‰ํ•จ. ์ˆ™๋ฐ•๋น„์— ๋น„๊ต์  ๋†’์€ ์˜ˆ์‚ฐ์„ ์‚ฌ์šฉํ•จ.",
    2: "ํ•œ๊ตญ์„ ์ฒ˜์Œ ๋ฐฉ๋ฌธํ•œ ์—ฌํ–‰๊ฐ. ์งง์€ ๊ธฐ๊ฐ„ ์„œ์šธ์— ๋จธ๋ฌผ๋ฉฐ ์Œ์‹, ์‡ผํ•‘ ๋“ฑ ๋ชจ๋“  ๋ถ„์•ผ์—์„œ ์••๋„์ ์ธ ์†Œ๋น„๋ ฅ์„ ๋ณด์—ฌ์ฃผ๋Š” ๋Ÿญ์…”๋ฆฌ ์—ฌํ–‰์„ ์ฆ๊น€.",
    3: "์‡ผํ•‘๊ณผ ๋ง›์ง‘ ํƒ๋ฐฉ์„ ๋ชฉ์ ์œผ๋กœ ์„œ์šธ์„ ์ž์ฃผ ์žฌ๋ฐฉ๋ฌธํ•˜๋Š” ์—ฌํ–‰๊ฐ. ๋งค์šฐ ์งง์€ ๊ธฐ๊ฐ„ ๋จธ๋ฌผ๋ฉฐ ์—ฌํ–‰ ๋ชฉ์ ์„ ์ง‘์ค‘์ ์œผ๋กœ ๋‹ฌ์„ฑํ•˜๊ณ , ์‹๋น„์— ์ง€์ถœ ๋น„์ค‘์ด ๋งค์šฐ ๋†’์Œ. ๋ฌธํ™”๋‚˜ ์ž์—ฐ๋ณด๋‹ค ์‡ผํ•‘๊ณผ ๋ฏธ์‹์— ๊ด€์‹ฌ์ด ์ง‘์ค‘๋จ.",
    4: "ํ•œ๊ตญ ์—ฌํ–‰ ๊ฒฝํ—˜์ด ํ’๋ถ€ํ•œ ์žฌ๋ฐฉ๋ฌธ๊ฐ. ์„œ์šธ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ ์ „๊ตญ์„ ์—ฌํ–‰ํ•˜๋ฉฐ, ํŠนํžˆ ๋‹ค์–‘ํ•œ ์ง€์—ญ์˜ ์Œ์‹์„ ์ฆ๊ธฐ๋Š” ๋ฏธ์‹ ํ™œ๋™์— ๊ด€์‹ฌ์ด ๋งค์šฐ ๋†’์Œ.",
    5: "ํ•œ๊ตญ์„ ์ฒ˜์Œ ๋ฐฉ๋ฌธํ•˜๋Š” ์—ฌํ–‰๊ฐ. ๊ธด ๊ธฐ๊ฐ„ ๋™์•ˆ ๋จธ๋ฌด๋ฅด๋ฉฐ ์„œ์šธ์„ ๋„˜์–ด ์ง€๋ฐฉ, ํŠนํžˆ ๊ฒฝ์ƒ๋„ ์ง€์—ญ์˜ ์ž์—ฐ ๊ฒฝ๊ด€๊ณผ ๋ฌธํ™” ์œ ์‚ฐ์„ ๊นŠ์ด ์žˆ๊ฒŒ ํƒํ—˜ํ•˜๋Š” ๊ฒƒ์— ๊ด€์‹ฌ์ด ์••๋„์ ์œผ๋กœ ๋†’์Œ. ์˜ˆ์‚ฐ์€ ๋น„๊ต์  ์ ๊ฒŒ ์‚ฌ์šฉํ•จ.",
    6: "ํ•œ๊ตญ์„ ์ฒ˜์Œ ๋ฐฉ๋ฌธํ•˜๋Š” ์—ฌํ–‰๊ฐ. ๊ธด ๊ธฐ๊ฐ„ ๋™์•ˆ ์ง€๋ฐฉ, ํŠนํžˆ ๊ฒฝ์ƒ๋„๋ฅผ ์—ฌํ–‰ํ•˜๋ฉฐ ํ•œ๊ตญ์˜ ์ž์—ฐ ๊ฒฝ๊ด€๊ณผ ๋ฌธํ™” ์œ ์‚ฐ์— ๋งค์šฐ ๋†’์€ ๋งŒ์กฑ๋„์™€ ๊นŠ์€ ๊ฐ๋ช…์„ ๋А๋‚Œ. ์žฌ๋ฐฉ๋ฌธ ์˜ํ–ฅ๋„ ๋†’์€ ์ด์ƒ์ ์ธ ํƒ๋ฐฉํ˜• ์—ฌํ–‰๊ฐ."
}
# --- ๋ชจ๋ธ ๋ฐ ๋ฐ์ดํ„ฐ ๋กœ๋“œ ---
try:
preprocessor = joblib.load('./models/preprocessor.joblib')
pca = joblib.load('./models/pca.joblib')
kmeans = joblib.load('./models/kmeans.joblib')
imputation_base_data = pd.read_csv('./models/imputation_base_data.csv', encoding='utf-8-sig')
with open('./models/variable_weights.json', 'r', encoding='utf-8') as f:
VARIABLE_WEIGHTS = json.load(f)
except FileNotFoundError:
print("๋ชจ๋ธ ํŒŒ์ผ์ด ์—†์Šต๋‹ˆ๋‹ค. train_model.py๋ฅผ ๋จผ์ € ์‹คํ–‰ํ•ด์ฃผ์„ธ์š”.")
preprocessor, pca, kmeans, imputation_base_data = None, None, None, None
# ๋ณ€์ˆ˜ ์ •์˜
categorical_cols = ['country', 'gender', 'age', 'revisit_indicator', 'visit_local_indicator', 'planned_activity']
numerical_cols = ['stay_duration', 'accommodation_percent', 'food_percent', 'shopping_percent', 'food', 'landscape', 'heritage', 'language', 'safety', 'budget', 'accommodation', 'transport', 'navigation']
used_variables = categorical_cols + numerical_cols
def query_llm_for_variables(user_query, use_prompt=True, use_fewshot=True):
prompt_parts = []
if use_prompt:
with open(PROMPT_PATH, "r", encoding="utf-8") as f:
custom_prompt = f.read()
prompt_parts.append(custom_prompt)
if use_fewshot:
with open(FEWSHOT_PATH, "r", encoding="utf-8") as f:
few_shot_examples = f.read()
prompt_parts.append(few_shot_examples)
full_prompt = "\n\n".join(prompt_parts)
messages = [
{"role": "system", "content": full_prompt},
{"role": "user", "content": user_query}
]
try:
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=messages,
response_format={"type": "json_object"} # tsy ์ถ”๊ฐ€: JSON ์‘๋‹ต ํ˜•์‹์„ ๊ฐ•์ œ
)
content = response.choices[0].message.content.strip()
return json.loads(content)
except Exception as e:
print("[ํŒŒ์‹ฑ ์‹คํŒจ]", e)
return {}
def impute_with_user_subgroup(user_input_dict, df_base=None):
    """Complete a partial user profile by statistical imputation.

    For each variable in ``used_variables`` the user did not supply, impute
    from the subgroup of ``df_base`` rows that match every value the user did
    supply (mean for numeric columns, mode for categorical columns); when the
    subgroup is empty, fall back to whole-dataset statistics.

    Parameters
    ----------
    user_input_dict : dict
        Partial mapping of variable name -> value (None means unknown).
    df_base : pandas.DataFrame, optional
        Reference data for imputation. Defaults to the module-level
        ``imputation_base_data``.

    Returns
    -------
    dict
        A value for every variable in ``used_variables``.
    """
    # BUG FIX: the default used to be `df_base=imputation_base_data`, which is
    # evaluated once at import time — if model loading failed (value None),
    # every later call crashed on `.copy()` even after data became available.
    # Resolving the default at call time keeps the interface backward compatible.
    if df_base is None:
        df_base = imputation_base_data

    known_info = {k: v for k, v in user_input_dict.items() if v is not None}

    # Narrow the reference data to rows matching everything the user told us.
    filtered_df = df_base.copy()
    for key, val in known_info.items():
        if key in filtered_df.columns:
            filtered_df = filtered_df[filtered_df[key].astype(str) == str(val)]

    imputed = {}
    for var in used_variables:
        if user_input_dict.get(var) is not None:
            imputed[var] = user_input_dict[var]
            continue
        # Prefer subgroup statistics; fall back to the full dataset.
        source = filtered_df if not filtered_df.empty else df_base
        if var in numerical_cols:
            imputed[var] = source[var].mean()
        elif var in categorical_cols:
            imputed[var] = source[var].mode().iloc[0]
    return imputed
def predict_cluster_from_query(variable_dict: dict):
    """Predict the KMeans cluster label for an already-extracted variable dict.

    No LLM call happens here: the input is completed via subgroup imputation,
    coerced to the dtypes the preprocessor was fitted on, then pushed through
    the preprocessor -> PCA -> KMeans pipeline. Returns the label, or None
    when the input is empty/falsy or the pipeline raises.
    """
    if not variable_dict:
        return None

    frame = pd.DataFrame([impute_with_user_subgroup(variable_dict)])

    # Match training dtypes: categoricals as str, numerics coerced (NaN on failure).
    for col in categorical_cols:
        if col in frame.columns:
            frame[col] = frame[col].astype(str)
    for col in numerical_cols:
        if col in frame.columns:
            frame[col] = pd.to_numeric(frame[col], errors='coerce')

    try:
        reduced = pca.transform(preprocessor.transform(frame))
        return kmeans.predict(reduced)[0]
    except Exception as e:
        print(f"[ํด๋Ÿฌ์Šคํ„ฐ ์˜ˆ์ธก ์‹คํŒจ] {e}")
        return None
# ==================== New helper functions ====================
def _calculate_info_score(extracted_vars):
    """Return the information-sufficiency score: the sum of VARIABLE_WEIGHTS
    over every variable in ``extracted_vars`` with a non-None value (0.0 when
    no weights are loaded)."""
    if not VARIABLE_WEIGHTS:
        return 0.0
    known = (name for name, value in extracted_vars.items() if value is not None)
    current_score = sum(VARIABLE_WEIGHTS.get(name, 0) for name in known)
    print(f"์ •๋ณด ์ถฉ๋ถ„๋„ ์ ์ˆ˜: {current_score:.4f}")
    return current_score
def _generate_clarifying_question(user_query, context):
    """Ask the LLM to phrase a natural one-sentence follow-up question about
    the highest-weighted variables still missing from ``context``.

    At most the top two missing variables (by VARIABLE_WEIGHTS) are mentioned.
    Falls back to canned Korean questions when nothing is missing or the API
    call fails.
    """
    # Human-readable (Korean) descriptions for the variables worth asking about.
    variable_map = {
        'revisit_indicator': '์ด๋ฒˆ์ด ํ•œ๊ตญ ์ฒซ ๋ฐฉ๋ฌธ์ธ์ง€, ํ˜น์€ ์ด์ „์— ํ•œ๊ตญ์„ ๋ฐฉ๋ฌธํ•œ ์ ์ด ์žˆ๋Š”์ง€',
        'visit_local_indicator': '์ˆ˜๋„๊ถŒ(์„œ์šธ/๊ฒฝ๊ธฐ/์ธ์ฒœ) ์™ธ ๋‹ค๋ฅธ ์ง€์—ญ์„ ๋ฐฉ๋ฌธํ•  ๊ณ„ํš์ด ์žˆ๋Š”์ง€',
        'stay_duration': 'ํ•œ๊ตญ ์—ฌํ–‰ ๊ธฐ๊ฐ„',
        'planned_activity': 'ํ•œ๊ตญ ์—ฌํ–‰์„ ํ•˜๊ธฐ์œ„ํ•ด ๊ณ„ํšํ•œ ํ™œ๋™'
    }

    missing_vars = []
    if VARIABLE_WEIGHTS:
        # Highest-weight variables first, so the most informative gap is asked about.
        by_weight = sorted(VARIABLE_WEIGHTS, key=VARIABLE_WEIGHTS.get, reverse=True)
        missing_vars = [
            variable_map[name]
            for name in by_weight
            if context.get(name) is None and name in variable_map
        ]

    if not missing_vars:
        return "์—ฌํ–‰์— ๋Œ€ํ•ด ์กฐ๊ธˆ๋งŒ ๋” ์ž์„ธํžˆ ๋ง์”€ํ•ด์ฃผ์‹œ๊ฒ ์–ด์š”?"

    question_prompt = f"""๋‹น์‹ ์€ ์นœ์ ˆํ•œ ์—ฌํ–‰ ํ”Œ๋ž˜๋„ˆ์ž…๋‹ˆ๋‹ค.
์‚ฌ์šฉ์ž๊ฐ€ ์•„๋ž˜์™€ ๊ฐ™์ด ์งˆ๋ฌธํ–ˆ์Šต๋‹ˆ๋‹ค.
์‚ฌ์šฉ์ž ์งˆ๋ฌธ: "{user_query}"
์‚ฌ์šฉ์ž ๋งž์ถค ์ถ”์ฒœ์„ ์œ„ํ•ด '{', '.join(missing_vars[:2])}' ์ •๋ณด๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.
์‚ฌ์šฉ์ž์˜ ์งˆ๋ฌธ ๋งฅ๋ฝ์— ๋งž์ถฐ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ์งˆ๋ฌธ์„ ํ•œ ๋ฌธ์žฅ์œผ๋กœ ๋งŒ๋“ค์–ด์ฃผ์„ธ์š”."""

    try:
        reply = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "system", "content": question_prompt}]
        )
        return reply.choices[0].message.content
    except Exception:
        return f"ํ˜น์‹œ ๊ณ„ํš ์ค‘์ธ {missing_vars[0]}์— ๋Œ€ํ•ด ์กฐ๊ธˆ ๋” ์•Œ๋ ค์ฃผ์‹ค ์ˆ˜ ์žˆ๋‚˜์š”?"
# --- Main entry point (redesigned) ---
def get_user_cluster(user_query: str, previous_context: dict = None):
    """Turn a user query (plus prior conversation context) into a cluster
    assignment or a clarifying follow-up question.

    Parameters
    ----------
    user_query : str
        The user's latest free-form message.
    previous_context : dict, optional
        Variables accumulated from earlier turns; merged with this turn's
        LLM extraction (non-None values win).

    Returns
    -------
    tuple
        (None, None)                          -- models/data failed to load
        ("SUCCESS", (label, profile))         -- enough info; cluster predicted
        ("FAIL", message)                     -- enough info but prediction failed
        ("RETRY_WITH_QUESTION", (question, context))
                                              -- not enough info; ask the user
    """
    # BUG FIX: when loading failed, imputation_base_data is None (not a
    # DataFrame), so the old `imputation_base_data.empty` raised
    # AttributeError; check for None before touching `.empty`.
    if (preprocessor is None or pca is None or kmeans is None
            or imputation_base_data is None or imputation_base_data.empty):
        return None, None

    # Merge newly extracted variables into the running conversation context.
    newly_extracted_vars = query_llm_for_variables(user_query)
    current_context = previous_context.copy() if previous_context else {}
    current_context.update({k: v for k, v in newly_extracted_vars.items() if v is not None})

    score = _calculate_info_score(current_context)
    if score > 0.50:
        # Enough information: run the clustering pipeline.
        cluster_label = predict_cluster_from_query(current_context)
        if cluster_label is None:
            return "FAIL", "ํด๋Ÿฌ์Šคํ„ฐ ์˜ˆ์ธก์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค."
        profile = CLUSTER_PROFILES.get(cluster_label, "์ •์˜๋˜์ง€ ์•Š์€ ํด๋Ÿฌ์Šคํ„ฐ์ž…๋‹ˆ๋‹ค.")
        return "SUCCESS", (cluster_label, profile)

    # Not enough information: ask a targeted follow-up question instead.
    question = _generate_clarifying_question(user_query, current_context)
    return "RETRY_WITH_QUESTION", (question, current_context)