Spaces:
Sleeping
Sleeping
# cluster_predictor.py
import joblib
import pandas as pd
from openai import OpenAI
import os
import json
from huggingface_hub import hf_hub_download

# Load the prompt files from the Hugging Face *dataset* repo.
PROMPT_PATH = hf_hub_download(
    repo_id="Syngyeon/seoulalpha-data",
    repo_type="dataset",  # must be "dataset" (the repo is a dataset, not a model)
    filename="data/prompt/custom_prompt_eng.txt"
)
FEWSHOT_PATH = hf_hub_download(
    repo_id="Syngyeon/seoulalpha-data",
    repo_type="dataset",  # must be "dataset" (the repo is a dataset, not a model)
    filename="data/prompt/custom_few_shot_learning_multi_language.txt"
)
# --- Initial setup ---
# NOTE(review): the OpenAI key is read from the API_KEY env var — confirm it is
# set in the deployment environment (returns None otherwise, failing at call time).
client = OpenAI(api_key=os.getenv("API_KEY"))

# Human-readable description for each KMeans cluster label, keyed by the
# integer label returned by `kmeans.predict`. Korean text (appears
# mojibake-encoded in this copy of the file — verify encoding at the source).
CLUSTER_PROFILES = {
    0: "๋ฌธํ, ์ญ์ฌ, ์์ฐ ํ๋ฐฉ์ ์ฃผ๋ชฉ์ ์ผ๋ก ๊ฐ์์ ํ๊ตญ์ ์ฌ๋ฐฉ๋ฌธํ๋ ์ฌํ๊ฐ. ๊ธด ์ฒด๋ฅ ๊ธฐ๊ฐ ๋์ ์์ธ๊ณผ ์ฌ๋ฌ ์ง๋ฐฉ(๊ฒฝ๊ธฐ, ๊ฐ์, ๊ฒฝ์)์ ํจ๊ป ๋ฐฉ๋ฌธํ๋ฉฐ, ๋งค์ฐ ์๋ฐํ๊ฒ ์๋นํ๋ ๊ฒฝํฅ์ด ์์.",
    1: "ํ๊ตญ์ ์ฒ์ ๋ฐฉ๋ฌธํ ์ฌํ๊ฐ. ์งง์ ๊ธฐ๊ฐ ๋์ ์์ธ์๋ง ๋จธ๋ฌด๋ฅด๋ฉฐ ์์๊ณผ ๋ฏธ์ ํ๋ฐฉ์ ๊ฐ์ฅ ํฐ ๊ด์ฌ์ ๋๊ณ ์ฌํํจ. ์๋ฐ๋น์ ๋น๊ต์ ๋์ ์์ฐ์ ์ฌ์ฉํจ.",
    2: "ํ๊ตญ์ ์ฒ์ ๋ฐฉ๋ฌธํ ์ฌํ๊ฐ. ์งง์ ๊ธฐ๊ฐ ์์ธ์ ๋จธ๋ฌผ๋ฉฐ ์์, ์ผํ ๋ฑ ๋ชจ๋ ๋ถ์ผ์์ ์๋์ ์ธ ์๋น๋ ฅ์ ๋ณด์ฌ์ฃผ๋ ๋ญ์ ๋ฆฌ ์ฌํ์ ์ฆ๊น.",
    3: "์ผํ๊ณผ ๋ง์ง ํ๋ฐฉ์ ๋ชฉ์ ์ผ๋ก ์์ธ์ ์์ฃผ ์ฌ๋ฐฉ๋ฌธํ๋ ์ฌํ๊ฐ. ๋งค์ฐ ์งง์ ๊ธฐ๊ฐ ๋จธ๋ฌผ๋ฉฐ ์ฌํ ๋ชฉ์ ์ ์ง์ค์ ์ผ๋ก ๋ฌ์ฑํ๊ณ , ์๋น์ ์ง์ถ ๋น์ค์ด ๋งค์ฐ ๋์. ๋ฌธํ๋ ์์ฐ๋ณด๋ค ์ผํ๊ณผ ๋ฏธ์์ ๊ด์ฌ์ด ์ง์ค๋จ.",
    4: "ํ๊ตญ ์ฌํ ๊ฒฝํ์ด ํ๋ถํ ์ฌ๋ฐฉ๋ฌธ๊ฐ. ์์ธ๋ฟ๋ง ์๋๋ผ ์ ๊ตญ์ ์ฌํํ๋ฉฐ, ํนํ ๋ค์ํ ์ง์ญ์ ์์์ ์ฆ๊ธฐ๋ ๋ฏธ์ ํ๋์ ๊ด์ฌ์ด ๋งค์ฐ ๋์.",
    5: "ํ๊ตญ์ ์ฒ์ ๋ฐฉ๋ฌธํ๋ ์ฌํ๊ฐ. ๊ธด ๊ธฐ๊ฐ ๋์ ๋จธ๋ฌด๋ฅด๋ฉฐ ์์ธ์ ๋์ด ์ง๋ฐฉ, ํนํ ๊ฒฝ์๋ ์ง์ญ์ ์์ฐ ๊ฒฝ๊ด๊ณผ ๋ฌธํ ์ ์ฐ์ ๊น์ด ์๊ฒ ํํํ๋ ๊ฒ์ ๊ด์ฌ์ด ์๋์ ์ผ๋ก ๋์. ์์ฐ์ ๋น๊ต์ ์ ๊ฒ ์ฌ์ฉํจ.",
    6: "ํ๊ตญ์ ์ฒ์ ๋ฐฉ๋ฌธํ๋ ์ฌํ๊ฐ. ๊ธด ๊ธฐ๊ฐ ๋์ ์ง๋ฐฉ, ํนํ ๊ฒฝ์๋๋ฅผ ์ฌํํ๋ฉฐ ํ๊ตญ์ ์์ฐ ๊ฒฝ๊ด๊ณผ ๋ฌธํ ์ ์ฐ์ ๋งค์ฐ ๋์ ๋ง์กฑ๋์ ๊น์ ๊ฐ๋ช ์ ๋๋. ์ฌ๋ฐฉ๋ฌธ ์ํฅ๋ ๋์ ์ด์์ ์ธ ํ๋ฐฉํ ์ฌํ๊ฐ."
}
# --- Load models and data ---
# Loads the fitted preprocessing pipeline, PCA, KMeans model, the imputation
# base data, and the per-variable information weights produced by train_model.py.
try:
    preprocessor = joblib.load('./models/preprocessor.joblib')
    pca = joblib.load('./models/pca.joblib')
    kmeans = joblib.load('./models/kmeans.joblib')
    imputation_base_data = pd.read_csv('./models/imputation_base_data.csv', encoding='utf-8-sig')
    with open('./models/variable_weights.json', 'r', encoding='utf-8') as f:
        VARIABLE_WEIGHTS = json.load(f)
except FileNotFoundError:
    print("๋ชจ๋ธ ํ์ผ์ด ์์ต๋๋ค. train_model.py๋ฅผ ๋จผ์ ์คํํด์ฃผ์ธ์.")
    preprocessor, pca, kmeans, imputation_base_data = None, None, None, None
    # Bug fix: VARIABLE_WEIGHTS was left undefined on this path, which made
    # _calculate_info_score() / get_user_cluster() raise NameError instead of
    # degrading gracefully. An empty dict keeps the `if not VARIABLE_WEIGHTS`
    # guards working.
    VARIABLE_WEIGHTS = {}

# Feature columns expected by the fitted preprocessor.
categorical_cols = ['country', 'gender', 'age', 'revisit_indicator', 'visit_local_indicator', 'planned_activity']
numerical_cols = ['stay_duration', 'accommodation_percent', 'food_percent', 'shopping_percent', 'food', 'landscape', 'heritage', 'language', 'safety', 'budget', 'accommodation', 'transport', 'navigation']
used_variables = categorical_cols + numerical_cols
def query_llm_for_variables(user_query, use_prompt=True, use_fewshot=True):
    """Extract traveler variables from a free-form query via the chat model.

    Builds the system prompt from the custom-prompt and/or few-shot files,
    forces a JSON-object response, and returns the parsed dict. Returns an
    empty dict when the API call or JSON parsing fails.
    """
    sections = []
    for enabled, path in ((use_prompt, PROMPT_PATH), (use_fewshot, FEWSHOT_PATH)):
        if enabled:
            with open(path, "r", encoding="utf-8") as fh:
                sections.append(fh.read())
    system_prompt = "\n\n".join(sections)

    try:
        completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_query},
            ],
            # Force the model to reply with a well-formed JSON object.
            response_format={"type": "json_object"},
        )
        raw = completion.choices[0].message.content.strip()
        return json.loads(raw)
    except Exception as e:
        print("[ํ์ฑ ์คํจ]", e)
        return {}
def impute_with_user_subgroup(user_input_dict, df_base=None):
    """Fill in missing model variables from statistics of similar travelers.

    Rows of the base data matching the user's known values are used first;
    numerical variables are imputed with the subgroup mean, categorical ones
    with the subgroup mode. When no matching rows remain, the full base data
    is used instead.

    Args:
        user_input_dict: partial mapping of variable name -> value (None = missing).
        df_base: imputation base DataFrame; defaults to the module-level
            `imputation_base_data` resolved at call time.

    Returns:
        dict with every variable in `used_variables` populated.
    """
    # Bug fix: the default was previously bound to the module-level DataFrame
    # at import time, so it crashed with AttributeError (None.copy()) whenever
    # model loading had failed, and ignored any later rebinding of the global.
    if df_base is None:
        df_base = imputation_base_data

    known_info = {k: v for k, v in user_input_dict.items() if v is not None}

    # Narrow the base data down to the subgroup matching all known values.
    filtered_df = df_base.copy()
    for key, val in known_info.items():
        if key in filtered_df.columns:
            filtered_df = filtered_df[filtered_df[key].astype(str) == str(val)]

    imputed = {}
    for var in used_variables:
        value = user_input_dict.get(var)
        if value is not None:
            imputed[var] = value
            continue
        source = filtered_df if not filtered_df.empty else df_base
        if var in numerical_cols:
            imputed[var] = source[var].mean()
        elif var in categorical_cols:
            modes = source[var].mode()
            # Robustness: mode() is empty when the column is all-NaN in the
            # subgroup; fall back to the full base data before giving up.
            if modes.empty and source is not df_base:
                modes = df_base[var].mode()
            imputed[var] = modes.iloc[0] if not modes.empty else None
    return imputed
def predict_cluster_from_query(variable_dict: dict):
    """Predict the KMeans cluster label for a (possibly partial) variable dict.

    Missing variables are imputed from similar travelers, then the fitted
    preprocessor -> PCA -> KMeans pipeline is applied. Returns the integer
    cluster label, or None when the input is empty or prediction fails.
    """
    # This function no longer calls the LLM; it only predicts from given info.
    if not variable_dict:
        return None

    frame = pd.DataFrame([impute_with_user_subgroup(variable_dict)])

    # Coerce dtypes to what the fitted preprocessor expects.
    for column in frame.columns:
        if column in categorical_cols:
            frame[column] = frame[column].astype(str)
        elif column in numerical_cols:
            frame[column] = pd.to_numeric(frame[column], errors='coerce')

    try:
        reduced = pca.transform(preprocessor.transform(frame))
        return kmeans.predict(reduced)[0]
    except Exception as e:
        print(f"[ํด๋ฌ์คํฐ ์์ธก ์คํจ] {e}")
        return None
# ==================== New addition: helper functions ====================
def _calculate_info_score(extracted_vars):
    """Return the information-sufficiency score: the sum of weights of all
    extracted (non-None) variables, 0.0 when no weights are loaded."""
    if not VARIABLE_WEIGHTS:
        return 0.0
    current_score = 0
    for name, value in extracted_vars.items():
        if value is not None:
            # Unknown variables contribute nothing to the score.
            current_score += VARIABLE_WEIGHTS.get(name, 0)
    print(f"์ ๋ณด ์ถฉ๋ถ๋ ์ ์: {current_score:.4f}")
    return current_score
def _generate_clarifying_question(user_query, context):
    """Build one natural follow-up question asking for the most informative
    variables still missing from *context*.

    Uses the LLM to phrase the question around the user's original query;
    falls back to a canned question when nothing is missing or the API fails.
    """
    # Friendly, askable descriptions for the variables we may request.
    variable_map = {
        'revisit_indicator': '์ด๋ฒ์ด ํ๊ตญ ์ฒซ ๋ฐฉ๋ฌธ์ธ์ง, ํน์ ์ด์ ์ ํ๊ตญ์ ๋ฐฉ๋ฌธํ ์ ์ด ์๋์ง',
        'visit_local_indicator': '์๋๊ถ(์์ธ/๊ฒฝ๊ธฐ/์ธ์ฒ) ์ธ ๋ค๋ฅธ ์ง์ญ์ ๋ฐฉ๋ฌธํ ๊ณํ์ด ์๋์ง',
        'stay_duration': 'ํ๊ตญ ์ฌํ ๊ธฐ๊ฐ',
        'planned_activity': 'ํ๊ตญ ์ฌํ์ ํ๊ธฐ์ํด ๊ณํํ ํ๋'
    }

    missing_vars = []
    if VARIABLE_WEIGHTS:
        # Ask about the highest-weight (most informative) variables first.
        ranked = sorted(VARIABLE_WEIGHTS, key=VARIABLE_WEIGHTS.get, reverse=True)
        missing_vars = [
            variable_map[name]
            for name in ranked
            if context.get(name) is None and name in variable_map
        ]

    if not missing_vars:
        return "์ฌํ์ ๋ํด ์กฐ๊ธ๋ง ๋ ์์ธํ ๋ง์ํด์ฃผ์๊ฒ ์ด์?"

    question_prompt = f"""๋น์ ์ ์น์ ํ ์ฌํ ํ๋๋์ ๋๋ค.
์ฌ์ฉ์๊ฐ ์๋์ ๊ฐ์ด ์ง๋ฌธํ์ต๋๋ค.
์ฌ์ฉ์ ์ง๋ฌธ: "{user_query}"
์ฌ์ฉ์ ๋ง์ถค ์ถ์ฒ์ ์ํด '{', '.join(missing_vars[:2])}' ์ ๋ณด๊ฐ ํ์ํฉ๋๋ค.
์ฌ์ฉ์์ ์ง๋ฌธ ๋งฅ๋ฝ์ ๋ง์ถฐ ์์ฐ์ค๋ฝ๊ฒ ์ง๋ฌธ์ ํ ๋ฌธ์ฅ์ผ๋ก ๋ง๋ค์ด์ฃผ์ธ์."""

    try:
        reply = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "system", "content": question_prompt}],
        )
        return reply.choices[0].message.content
    except Exception:
        # API failure: ask about the single most important missing variable.
        return f"ํน์ ๊ณํ ์ค์ธ {missing_vars[0]}์ ๋ํด ์กฐ๊ธ ๋ ์๋ ค์ฃผ์ค ์ ์๋์?"
# --- Conversation flow function (redesigned) ---
def get_user_cluster(user_query: str, previous_context: dict = None):
    """Extract traveler variables from *user_query*, merge them with any
    previously collected context, and either predict a cluster or ask for more
    information.

    Args:
        user_query: the user's free-form message.
        previous_context: variables gathered in earlier turns (not mutated).

    Returns one of:
        (None, None)                                 -- models/data failed to load
        ("SUCCESS", (cluster_label, profile_text))   -- cluster predicted
        ("FAIL", error_message)                      -- prediction failed
        ("RETRY_WITH_QUESTION", (question, context)) -- ask the user and retry
    """
    # Bug fix: when loading failed, imputation_base_data is None (not an empty
    # DataFrame), so `imputation_base_data.empty` raised AttributeError; check
    # for None before touching `.empty`.
    if (preprocessor is None or pca is None or kmeans is None
            or imputation_base_data is None or imputation_base_data.empty):
        return None, None

    newly_extracted_vars = query_llm_for_variables(user_query)

    # Merge this turn's non-None extractions over the previous context.
    current_context = dict(previous_context) if previous_context else {}
    current_context.update({k: v for k, v in newly_extracted_vars.items() if v is not None})

    score = _calculate_info_score(current_context)
    if score > 0.50:
        # Enough information: run the clustering pipeline.
        cluster_label = predict_cluster_from_query(current_context)
        if cluster_label is None:
            return "FAIL", "ํด๋ฌ์คํฐ ์์ธก์ ์คํจํ์ต๋๋ค."
        profile = CLUSTER_PROFILES.get(cluster_label, "์ ์๋์ง ์์ ํด๋ฌ์คํฐ์ ๋๋ค.")
        return "SUCCESS", (cluster_label, profile)

    # Not enough information: ask the user a clarifying question and keep the
    # accumulated context for the next turn.
    question = _generate_clarifying_question(user_query, current_context)
    return "RETRY_WITH_QUESTION", (question, current_context)