Spaces:
Running
Running
Upload 12 files
Browse files- README.md +17 -16
- app.py +127 -70
- cluster_predictor.py +188 -0
- models/feature_importance_ranking.csv +57 -0
- models/imputation_base_data.csv +0 -0
- models/kmeans.joblib +3 -0
- models/pca.joblib +3 -0
- models/preprocessor.joblib +3 -0
- models/variable_weights.json +18 -0
- rag_retriever.py +146 -0
- region_extractor.py +97 -0
- requirements.txt +0 -0
README.md
CHANGED
|
@@ -1,16 +1,17 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: Seoulalpha
|
| 3 |
-
emoji: ๐ฌ
|
| 4 |
-
colorFrom: yellow
|
| 5 |
-
colorTo: purple
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.42.0
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
hf_oauth: true
|
| 11 |
-
hf_oauth_scopes:
|
| 12 |
-
- inference-api
|
| 13 |
-
license:
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Seoulalpha
|
| 3 |
+
emoji: ๐ฌ
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.42.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
hf_oauth: true
|
| 11 |
+
hf_oauth_scopes:
|
| 12 |
+
- inference-api
|
| 13 |
+
license: apache-2.0
|
| 14 |
+
short_description: Travel Spots in Korea chatbot recommender
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
|
app.py
CHANGED
|
@@ -1,70 +1,127 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
import gradio as gr
|
| 3 |
+
from langdetect import detect
|
| 4 |
+
from deep_translator import GoogleTranslator
|
| 5 |
+
|
| 6 |
+
# ๋ชจ๋ import
|
| 7 |
+
from cluster_predictor import get_user_cluster
|
| 8 |
+
from region_extractor import extract_region_from_query
|
| 9 |
+
from rag_retriever import get_rag_recommendation
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# ์ธ์ด ์ฝ๋ ๋งคํ (deep_translator ํธํ)
|
| 13 |
+
LANG_CODE_MAP = {
|
| 14 |
+
"zh-cn": "zh-CN",
|
| 15 |
+
"zh-tw": "zh-TW",
|
| 16 |
+
"iw": "he",
|
| 17 |
+
}
|
| 18 |
+
def normalize_lang_code(code: str) -> str:
|
| 19 |
+
return LANG_CODE_MAP.get(code.lower(), code)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# --- Gradio์ฉ ๋ํ ํจ์ ---
|
| 23 |
+
def chatbot_interface(user_input, history, state):
|
| 24 |
+
if user_input.lower() in ["์ข
๋ฃ", "exit", "quit"]:
|
| 25 |
+
return history + [[user_input, "ํ๋ก๊ทธ๋จ์ ์ข
๋ฃํฉ๋๋ค."]], state
|
| 26 |
+
|
| 27 |
+
conversation_context = state.get("conversation_context", {})
|
| 28 |
+
full_conversation = state.get("full_conversation", [])
|
| 29 |
+
|
| 30 |
+
# --- Step1: ์
๋ ฅ ์ธ์ด ๊ฐ์ง & ํ๊ตญ์ด ๋ฒ์ญ ---
|
| 31 |
+
try:
|
| 32 |
+
detected = detect(user_input) # 'en', 'ja', 'fr', 'zh-cn' ...
|
| 33 |
+
input_lang = normalize_lang_code(detected)
|
| 34 |
+
except Exception as e:
|
| 35 |
+
return history + [[user_input, f"โ ์ธ์ด ๊ฐ์ง ์ค๋ฅ: {e}"]], state
|
| 36 |
+
|
| 37 |
+
if input_lang != "ko":
|
| 38 |
+
try:
|
| 39 |
+
current_query = GoogleTranslator(source=input_lang, target="ko").translate(user_input)
|
| 40 |
+
except Exception as e:
|
| 41 |
+
return history + [[user_input, f"โ ๋ฒ์ญ ์ค๋ฅ: {e}"]], state
|
| 42 |
+
else:
|
| 43 |
+
current_query = user_input
|
| 44 |
+
|
| 45 |
+
cluster_info = None
|
| 46 |
+
max_turns = 3
|
| 47 |
+
|
| 48 |
+
# ํด๋ฌ์คํฐ ํ์ ๋ฃจํ
|
| 49 |
+
for turn in range(max_turns):
|
| 50 |
+
full_conversation.append(current_query)
|
| 51 |
+
status, data = get_user_cluster(current_query, conversation_context)
|
| 52 |
+
|
| 53 |
+
if status == "SUCCESS":
|
| 54 |
+
cluster_info = data
|
| 55 |
+
break
|
| 56 |
+
elif status == "RETRY_WITH_QUESTION":
|
| 57 |
+
question_to_user, updated_context = data
|
| 58 |
+
conversation_context = updated_context
|
| 59 |
+
|
| 60 |
+
# ์ง๋ฌธ๋ ์
๋ ฅ ์ธ์ด๋ก ๋ฒ์ญํด์ ์ฌ์ฉ์์๊ฒ ๋ณด์ฌ์ค
|
| 61 |
+
if input_lang != "ko":
|
| 62 |
+
try:
|
| 63 |
+
question_to_user = GoogleTranslator(source="ko", target=input_lang).translate(question_to_user)
|
| 64 |
+
except:
|
| 65 |
+
pass
|
| 66 |
+
|
| 67 |
+
# state ์
๋ฐ์ดํธ
|
| 68 |
+
state["conversation_context"] = conversation_context
|
| 69 |
+
state["full_conversation"] = full_conversation
|
| 70 |
+
return history + [[user_input, question_to_user]], state
|
| 71 |
+
|
| 72 |
+
elif status == "FAIL":
|
| 73 |
+
fail_msg = "์ต์ข
ํด๋ฌ์คํฐ ๋ถ์์ ์คํจํ์ต๋๋ค."
|
| 74 |
+
if input_lang != "ko":
|
| 75 |
+
try:
|
| 76 |
+
fail_msg = GoogleTranslator(source="ko", target=input_lang).translate(fail_msg)
|
| 77 |
+
except:
|
| 78 |
+
pass
|
| 79 |
+
return history + [[user_input, fail_msg]], state
|
| 80 |
+
|
| 81 |
+
# RAG ์คํ
|
| 82 |
+
if cluster_info:
|
| 83 |
+
cluster_id, cluster_profile = cluster_info
|
| 84 |
+
final_query_for_rag = " ".join(full_conversation)
|
| 85 |
+
region_keywords = extract_region_from_query(final_query_for_rag)
|
| 86 |
+
|
| 87 |
+
rag_query = f"{cluster_profile} ํน์ง์ ๊ฐ์ง ์ฌํ๊ฐ์ด '{final_query_for_rag}'์ ๊ฐ์ ์ฌํ์ ํ ๋ ๊ฐ๊ธฐ ์ข์ ๊ณณ"
|
| 88 |
+
final_answer_ko = get_rag_recommendation(rag_query, region_keywords)
|
| 89 |
+
|
| 90 |
+
# ์ต์ข
๋ต๋ณ๋ ์
๋ ฅ ์ธ์ด๋ก ๋ค์ ๋ฒ์ญ
|
| 91 |
+
final_answer = final_answer_ko
|
| 92 |
+
if input_lang != "ko":
|
| 93 |
+
try:
|
| 94 |
+
final_answer = GoogleTranslator(source="ko", target=input_lang).translate(final_answer_ko)
|
| 95 |
+
except:
|
| 96 |
+
final_answer = f"โ ๊ฒฐ๊ณผ ๋ฒ์ญ ์ค๋ฅ: {final_answer_ko}"
|
| 97 |
+
|
| 98 |
+
# state ์
๋ฐ์ดํธ
|
| 99 |
+
state["conversation_context"] = conversation_context
|
| 100 |
+
state["full_conversation"] = full_conversation
|
| 101 |
+
return history + [[user_input, final_answer]], state
|
| 102 |
+
|
| 103 |
+
else:
|
| 104 |
+
fail_msg = "์ถ์ฒ์ ์์ฑํ ์ ์์ต๋๋ค."
|
| 105 |
+
if input_lang != "ko":
|
| 106 |
+
try:
|
| 107 |
+
fail_msg = GoogleTranslator(source="ko", target=input_lang).translate(fail_msg)
|
| 108 |
+
except:
|
| 109 |
+
pass
|
| 110 |
+
return history + [[user_input, fail_msg]], state
|
| 111 |
+
|
| 112 |
+
# --- Gradio UI ์ ์ ---
|
| 113 |
+
with gr.Blocks() as demo:
|
| 114 |
+
gr.Markdown("## โ๏ธ ์ฌํ ์ถ์ฒ ์ฑ๋ด")
|
| 115 |
+
|
| 116 |
+
chatbot = gr.Chatbot(height=500)
|
| 117 |
+
msg = gr.Textbox(label="์ฌ์ฉ์ ์
๋ ฅ")
|
| 118 |
+
state = gr.State({"conversation_context": {}, "full_conversation": []})
|
| 119 |
+
|
| 120 |
+
def respond(message, chat_history, state):
|
| 121 |
+
response, new_state = chatbot_interface(message, chat_history, state)
|
| 122 |
+
return "", response, new_state
|
| 123 |
+
|
| 124 |
+
msg.submit(respond, [msg, chatbot, state], [msg, chatbot, state])
|
| 125 |
+
|
| 126 |
+
if __name__ == "__main__":
|
| 127 |
+
demo.launch(show_api=False, debug=True)
|
cluster_predictor.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# cluster_predictor.py
|
| 2 |
+
|
| 3 |
+
import joblib
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from openai import OpenAI
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
from huggingface_hub import hf_hub_download
|
| 9 |
+
|
| 10 |
+
# Hugging Face dataset repo์์ prompt ํ์ผ ๋ก๋
|
| 11 |
+
PROMPT_PATH = hf_hub_download(
|
| 12 |
+
repo_id="Syngyeon/seoulalpha-data",
|
| 13 |
+
repo_type="dataset", # โ
๋ฐ๋์ dataset์ผ๋ก ์ง์
|
| 14 |
+
filename="data/prompt/custom_prompt_eng.txt"
|
| 15 |
+
)
|
| 16 |
+
FEWSHOT_PATH = hf_hub_download(
|
| 17 |
+
repo_id="Syngyeon/seoulalpha-data",
|
| 18 |
+
repo_type="dataset", # โ
๋ฐ๋์ dataset์ผ๋ก ์ง์
|
| 19 |
+
filename="data/prompt/custom_few_shot_learning_multi_language.txt"
|
| 20 |
+
)
|
| 21 |
+
# --- ์ด๊ธฐ ์ค์ ---
|
| 22 |
+
client = OpenAI(api_key=os.getenv("API_KEY"))
|
| 23 |
+
|
| 24 |
+
CLUSTER_PROFILES = {
|
| 25 |
+
0: "๋ฌธํ, ์ญ์ฌ, ์์ฐ ํ๋ฐฉ์ ์ฃผ๋ชฉ์ ์ผ๋ก ๊ฐ์์ ํ๊ตญ์ ์ฌ๋ฐฉ๋ฌธํ๋ ์ฌํ๊ฐ. ๊ธด ์ฒด๋ฅ ๊ธฐ๊ฐ ๋์ ์์ธ๊ณผ ์ฌ๋ฌ ์ง๋ฐฉ(๊ฒฝ๊ธฐ, ๊ฐ์, ๊ฒฝ์)์ ํจ๊ป ๋ฐฉ๋ฌธํ๋ฉฐ, ๋งค์ฐ ์๋ฐํ๊ฒ ์๋นํ๋ ๊ฒฝํฅ์ด ์์.",
|
| 26 |
+
1: "ํ๊ตญ์ ์ฒ์ ๋ฐฉ๋ฌธํ ์ฌํ๊ฐ. ์งง์ ๊ธฐ๊ฐ ๋์ ์์ธ์๋ง ๋จธ๋ฌด๋ฅด๋ฉฐ ์์๊ณผ ๋ฏธ์ ํ๋ฐฉ์ ๊ฐ์ฅ ํฐ ๊ด์ฌ์ ๋๊ณ ์ฌํํจ. ์๋ฐ๋น์ ๋น๊ต์ ๋์ ์์ฐ์ ์ฌ์ฉํจ.",
|
| 27 |
+
2: "ํ๊ตญ์ ์ฒ์ ๋ฐฉ๋ฌธํ ์ฌํ๊ฐ. ์งง์ ๊ธฐ๊ฐ ์์ธ์ ๋จธ๋ฌผ๋ฉฐ ์์, ์ผํ ๋ฑ ๋ชจ๋ ๋ถ์ผ์์ ์๋์ ์ธ ์๋น๋ ฅ์ ๋ณด์ฌ์ฃผ๋ ๋ญ์
๋ฆฌ ์ฌํ์ ์ฆ๊น.",
|
| 28 |
+
3: "์ผํ๊ณผ ๋ง์ง ํ๋ฐฉ์ ๋ชฉ์ ์ผ๋ก ์์ธ์ ์์ฃผ ์ฌ๋ฐฉ๋ฌธํ๋ ์ฌํ๊ฐ. ๋งค์ฐ ์งง์ ๊ธฐ๊ฐ ๋จธ๋ฌผ๋ฉฐ ์ฌํ ๋ชฉ์ ์ ์ง์ค์ ์ผ๋ก ๋ฌ์ฑํ๊ณ , ์๋น์ ์ง์ถ ๋น์ค์ด ๋งค์ฐ ๋์. ๋ฌธํ๋ ์์ฐ๋ณด๋ค ์ผํ๊ณผ ๋ฏธ์์ ๊ด์ฌ์ด ์ง์ค๋จ.",
|
| 29 |
+
4: "ํ๊ตญ ์ฌํ ๊ฒฝํ์ด ํ๋ถํ ์ฌ๋ฐฉ๋ฌธ๊ฐ. ์์ธ๋ฟ๋ง ์๋๋ผ ์ ๊ตญ์ ์ฌํํ๋ฉฐ, ํนํ ๋ค์ํ ์ง์ญ์ ์์์ ์ฆ๊ธฐ๋ ๋ฏธ์ ํ๋์ ๊ด์ฌ์ด ๋งค์ฐ ๋์.",
|
| 30 |
+
5: "ํ๊ตญ์ ์ฒ์ ๋ฐฉ๋ฌธํ๋ ์ฌํ๊ฐ. ๊ธด ๊ธฐ๊ฐ ๋์ ๋จธ๋ฌด๋ฅด๋ฉฐ ์์ธ์ ๋์ด ์ง๋ฐฉ, ํนํ ๊ฒฝ์๋ ์ง์ญ์ ์์ฐ ๊ฒฝ๊ด๊ณผ ๋ฌธํ ์ ์ฐ์ ๊น์ด ์๊ฒ ํํํ๋ ๊ฒ์ ๊ด์ฌ์ด ์๋์ ์ผ๋ก ๋์. ์์ฐ์ ๋น๊ต์ ์ ๊ฒ ์ฌ์ฉํจ.",
|
| 31 |
+
6: "ํ๊ตญ์ ์ฒ์ ๋ฐฉ๋ฌธํ๋ ์ฌํ๊ฐ. ๊ธด ๊ธฐ๊ฐ ๋์ ์ง๋ฐฉ, ํนํ ๊ฒฝ์๋๋ฅผ ์ฌํํ๋ฉฐ ํ๊ตญ์ ์์ฐ ๊ฒฝ๊ด๊ณผ ๋ฌธํ ์ ์ฐ์ ๋งค์ฐ ๋์ ๋ง์กฑ๋์ ๊น์ ๊ฐ๋ช
์ ๋๋. ์ฌ๋ฐฉ๋ฌธ ์ํฅ๋ ๋์ ์ด์์ ์ธ ํ๋ฐฉํ ์ฌํ๊ฐ."
|
| 32 |
+
}
|
| 33 |
+
# --- ๋ชจ๋ธ ๋ฐ ๋ฐ์ดํฐ ๋ก๋ ---
|
| 34 |
+
try:
|
| 35 |
+
preprocessor = joblib.load('./models/preprocessor.joblib')
|
| 36 |
+
pca = joblib.load('./models/pca.joblib')
|
| 37 |
+
kmeans = joblib.load('./models/kmeans.joblib')
|
| 38 |
+
imputation_base_data = pd.read_csv('./models/imputation_base_data.csv', encoding='utf-8-sig')
|
| 39 |
+
with open('./models/variable_weights.json', 'r', encoding='utf-8') as f:
|
| 40 |
+
VARIABLE_WEIGHTS = json.load(f)
|
| 41 |
+
except FileNotFoundError:
|
| 42 |
+
print("๋ชจ๋ธ ํ์ผ์ด ์์ต๋๋ค. train_model.py๋ฅผ ๋จผ์ ์คํํด์ฃผ์ธ์.")
|
| 43 |
+
preprocessor, pca, kmeans, imputation_base_data = None, None, None, None
|
| 44 |
+
|
| 45 |
+
# ๋ณ์ ์ ์
|
| 46 |
+
categorical_cols = ['country', 'gender', 'age', 'revisit_indicator', 'visit_local_indicator', 'planned_activity']
|
| 47 |
+
numerical_cols = ['stay_duration', 'accommodation_percent', 'food_percent', 'shopping_percent', 'food', 'landscape', 'heritage', 'language', 'safety', 'budget', 'accommodation', 'transport', 'navigation']
|
| 48 |
+
used_variables = categorical_cols + numerical_cols
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def query_llm_for_variables(user_query, use_prompt=True, use_fewshot=True):
|
| 52 |
+
prompt_parts = []
|
| 53 |
+
|
| 54 |
+
if use_prompt:
|
| 55 |
+
with open(PROMPT_PATH, "r", encoding="utf-8") as f:
|
| 56 |
+
custom_prompt = f.read()
|
| 57 |
+
prompt_parts.append(custom_prompt)
|
| 58 |
+
|
| 59 |
+
if use_fewshot:
|
| 60 |
+
with open(FEWSHOT_PATH, "r", encoding="utf-8") as f:
|
| 61 |
+
few_shot_examples = f.read()
|
| 62 |
+
prompt_parts.append(few_shot_examples)
|
| 63 |
+
|
| 64 |
+
full_prompt = "\n\n".join(prompt_parts)
|
| 65 |
+
|
| 66 |
+
messages = [
|
| 67 |
+
{"role": "system", "content": full_prompt},
|
| 68 |
+
{"role": "user", "content": user_query}
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
try:
|
| 72 |
+
response = client.chat.completions.create(
|
| 73 |
+
model="gpt-3.5-turbo",
|
| 74 |
+
messages=messages,
|
| 75 |
+
response_format={"type": "json_object"} # tsy ์ถ๊ฐ: JSON ์๋ต ํ์์ ๊ฐ์
|
| 76 |
+
)
|
| 77 |
+
content = response.choices[0].message.content.strip()
|
| 78 |
+
return json.loads(content)
|
| 79 |
+
except Exception as e:
|
| 80 |
+
print("[ํ์ฑ ์คํจ]", e)
|
| 81 |
+
return {}
|
| 82 |
+
|
| 83 |
+
def impute_with_user_subgroup(user_input_dict, df_base=imputation_base_data):
|
| 84 |
+
known_info = {k: v for k, v in user_input_dict.items() if v is not None}
|
| 85 |
+
filtered_df = df_base.copy()
|
| 86 |
+
for key, val in known_info.items():
|
| 87 |
+
if key in filtered_df.columns:
|
| 88 |
+
filtered_df = filtered_df[filtered_df[key].astype(str) == str(val)]
|
| 89 |
+
imputed = {}
|
| 90 |
+
for var in used_variables:
|
| 91 |
+
if user_input_dict.get(var) is not None:
|
| 92 |
+
imputed[var] = user_input_dict[var]
|
| 93 |
+
else:
|
| 94 |
+
if not filtered_df.empty:
|
| 95 |
+
if var in numerical_cols: imputed[var] = filtered_df[var].mean()
|
| 96 |
+
elif var in categorical_cols: imputed[var] = filtered_df[var].mode().iloc[0]
|
| 97 |
+
else:
|
| 98 |
+
if var in numerical_cols: imputed[var] = df_base[var].mean()
|
| 99 |
+
elif var in categorical_cols: imputed[var] = df_base[var].mode().iloc[0]
|
| 100 |
+
return imputed
|
| 101 |
+
|
| 102 |
+
def predict_cluster_from_query(variable_dict: dict):
|
| 103 |
+
# ์ด ํจ์๋ ๋ ์ด์ LLM์ ํธ์ถํ์ง ์๊ณ , ์ฃผ์ด์ง ์ ๋ณด๋ก ์์ธก๋ง ์ํ
|
| 104 |
+
if not variable_dict: return None
|
| 105 |
+
|
| 106 |
+
completed_input = impute_with_user_subgroup(variable_dict)
|
| 107 |
+
df = pd.DataFrame([completed_input])
|
| 108 |
+
|
| 109 |
+
for col in categorical_cols:
|
| 110 |
+
if col in df.columns: df[col] = df[col].astype(str)
|
| 111 |
+
for col in numerical_cols:
|
| 112 |
+
if col in df.columns: df[col] = pd.to_numeric(df[col], errors='coerce')
|
| 113 |
+
|
| 114 |
+
try:
|
| 115 |
+
X_processed = preprocessor.transform(df)
|
| 116 |
+
X_pca = pca.transform(X_processed)
|
| 117 |
+
return kmeans.predict(X_pca)[0]
|
| 118 |
+
except Exception as e:
|
| 119 |
+
print(f"[ํด๋ฌ์คํฐ ์์ธก ์คํจ] {e}")
|
| 120 |
+
return None
|
| 121 |
+
|
| 122 |
+
# ==================== ์ ๊ท ์ถ๊ฐ: ํฌํผ ํจ์ ====================
|
| 123 |
+
|
| 124 |
+
def _calculate_info_score(extracted_vars):
|
| 125 |
+
"""์ถ์ถ๋ ๋ณ์๋ค์ ๊ฐ์ค์น ํฉ์ผ๋ก ์ ๋ณด ์ถฉ๋ถ๋ ์ ์๋ฅผ ๊ณ์ฐํฉ๋๋ค."""
|
| 126 |
+
if not VARIABLE_WEIGHTS: return 0.0
|
| 127 |
+
current_score = sum(VARIABLE_WEIGHTS.get(var, 0) for var, value in extracted_vars.items() if value is not None)
|
| 128 |
+
print(f"์ ๋ณด ์ถฉ๋ถ๋ ์ ์: {current_score:.4f}")
|
| 129 |
+
return current_score
|
| 130 |
+
|
| 131 |
+
def _generate_clarifying_question(user_query, context):
|
| 132 |
+
variable_map = {
|
| 133 |
+
'revisit_indicator': '์ด๋ฒ์ด ํ๊ตญ ์ฒซ ๋ฐฉ๋ฌธ์ธ์ง, ํน์ ์ด์ ์ ํ๊ตญ์ ๋ฐฉ๋ฌธํ ์ ์ด ์๋์ง',
|
| 134 |
+
'visit_local_indicator': '์๋๊ถ(์์ธ/๊ฒฝ๊ธฐ/์ธ์ฒ) ์ธ ๋ค๋ฅธ ์ง์ญ์ ๋ฐฉ๋ฌธํ ๊ณํ์ด ์๋์ง',
|
| 135 |
+
'stay_duration': 'ํ๊ตญ ์ฌํ ๊ธฐ๊ฐ',
|
| 136 |
+
'planned_activity': 'ํ๊ตญ ์ฌํ์ ํ๊ธฐ์ํด ๊ณํํ ํ๋'
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
missing_vars = []
|
| 140 |
+
if VARIABLE_WEIGHTS:
|
| 141 |
+
sorted_vars = sorted(VARIABLE_WEIGHTS.keys(), key=lambda k: VARIABLE_WEIGHTS[k], reverse=True)
|
| 142 |
+
for var in sorted_vars:
|
| 143 |
+
if context.get(var) is None and var in variable_map:
|
| 144 |
+
missing_vars.append(variable_map[var])
|
| 145 |
+
|
| 146 |
+
if not missing_vars:
|
| 147 |
+
return "์ฌํ์ ๋ํด ์กฐ๊ธ๋ง ๋ ์์ธํ ๋ง์ํด์ฃผ์๊ฒ ์ด์?"
|
| 148 |
+
|
| 149 |
+
question_prompt = f"""๋น์ ์ ์น์ ํ ์ฌํ ํ๋๋์
๋๋ค.
|
| 150 |
+
์ฌ์ฉ์๊ฐ ์๋์ ๊ฐ์ด ์ง๋ฌธํ์ต๋๋ค.
|
| 151 |
+
์ฌ์ฉ์ ์ง๋ฌธ: "{user_query}"
|
| 152 |
+
|
| 153 |
+
์ฌ์ฉ์ ๋ง์ถค ์ถ์ฒ์ ์ํด '{', '.join(missing_vars[:2])}' ์ ๋ณด๊ฐ ํ์ํฉ๋๋ค.
|
| 154 |
+
์ฌ์ฉ์์ ์ง๋ฌธ ๋งฅ๋ฝ์ ๋ง์ถฐ ์์ฐ์ค๋ฝ๊ฒ ์ง๋ฌธ์ ํ ๋ฌธ์ฅ์ผ๋ก ๋ง๋ค์ด์ฃผ์ธ์."""
|
| 155 |
+
try:
|
| 156 |
+
response = client.chat.completions.create(model="gpt-3.5-turbo", messages=[{"role": "system", "content": question_prompt}])
|
| 157 |
+
return response.choices[0].message.content
|
| 158 |
+
except Exception:
|
| 159 |
+
return f"ํน์ ๊ณํ ์ค์ธ {missing_vars[0]}์ ๋ํด ์กฐ๊ธ ๋ ์๋ ค์ฃผ์ค ์ ์๋์?"
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# --- ๋ํ ์คํ ํจ์ (์ฌ์ค๊ณ) ---
|
| 164 |
+
def get_user_cluster(user_query: str, previous_context: dict = None):
|
| 165 |
+
if preprocessor is None or pca is None or kmeans is None or imputation_base_data.empty:
|
| 166 |
+
return None, None
|
| 167 |
+
|
| 168 |
+
#if not all([preprocessor, pca, kmeans, imputation_base_data, VARIABLE_WEIGHTS]):
|
| 169 |
+
# return "FAIL", "ํ์ ๋ชจ๋ธ/๋ฐ์ดํฐ ํ์ผ์ด ๋ก๋๋์ง ์์์ต๋๋ค."
|
| 170 |
+
|
| 171 |
+
newly_extracted_vars = query_llm_for_variables(user_query)
|
| 172 |
+
current_context = previous_context.copy() if previous_context else {}
|
| 173 |
+
current_context.update({k: v for k, v in newly_extracted_vars.items() if v is not None})
|
| 174 |
+
|
| 175 |
+
score = _calculate_info_score(current_context)
|
| 176 |
+
|
| 177 |
+
if score > 0.50:
|
| 178 |
+
#print("โ
์ ๋ณด๊ฐ ์ถฉ๋ถํ์ฌ ํด๋ฌ์คํฐ๋ง์ ์งํํฉ๋๋ค.")
|
| 179 |
+
cluster_label = predict_cluster_from_query(current_context)
|
| 180 |
+
if cluster_label is not None:
|
| 181 |
+
profile = CLUSTER_PROFILES.get(cluster_label, "์ ์๋์ง ์์ ํด๋ฌ์คํฐ์
๋๋ค.")
|
| 182 |
+
return "SUCCESS", (cluster_label, profile)
|
| 183 |
+
else:
|
| 184 |
+
return "FAIL", "ํด๋ฌ์คํฐ ์์ธก์ ์คํจํ์ต๋๋ค."
|
| 185 |
+
else:
|
| 186 |
+
#print("โ ๏ธ ์ ๋ณด๊ฐ ๋ถ์ถฉ๋ถํ์ฌ ์ฌ์ฉ์์๊ฒ ์ฌ์ง์ํฉ๋๋ค.")
|
| 187 |
+
question = _generate_clarifying_question(user_query, current_context)
|
| 188 |
+
return "RETRY_WITH_QUESTION", (question, current_context)
|
models/feature_importance_ranking.csv
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
๏ปฟfeature,importance
|
| 2 |
+
cat__gender_2,0.16612719377732216
|
| 3 |
+
cat__revisit_indicator_0,0.16455068671975487
|
| 4 |
+
cat__gender_1,0.1482217051794164
|
| 5 |
+
cat__revisit_indicator_1,0.13597727601444679
|
| 6 |
+
cat__visit_local_indicator_0,0.13319687700894334
|
| 7 |
+
cat__visit_local_indicator_1,0.10697746487224133
|
| 8 |
+
cat__visit_local_indicator_2,0.020241612346529863
|
| 9 |
+
num__stay_duration,0.014849368057557598
|
| 10 |
+
num__food_percent,0.014630130868015519
|
| 11 |
+
num__shopping_percent,0.009280205750283798
|
| 12 |
+
num__accommodation_percent,0.009180690732971732
|
| 13 |
+
cat__planned_activity_2.0,0.008696580924116888
|
| 14 |
+
cat__country_1,0.0070618551302561675
|
| 15 |
+
cat__planned_activity_3.0,0.005779261578544953
|
| 16 |
+
cat__age_2,0.004328397535810549
|
| 17 |
+
cat__planned_activity_4.0,0.003544786166398055
|
| 18 |
+
cat__planned_activity_99.0,0.003167551351550833
|
| 19 |
+
cat__country_2,0.0029464161975619727
|
| 20 |
+
num__language,0.002927291909464459
|
| 21 |
+
num__budget,0.0027502516728713754
|
| 22 |
+
num__safety,0.0025716066643480325
|
| 23 |
+
num__navigation,0.0025153003716470498
|
| 24 |
+
num__transport,0.002371096987708634
|
| 25 |
+
num__food,0.00230589888998026
|
| 26 |
+
num__accommodation,0.0022564692709160917
|
| 27 |
+
num__landscape,0.0021469948117270295
|
| 28 |
+
num__heritage,0.001992909406583649
|
| 29 |
+
cat__country_5,0.0018994328238773261
|
| 30 |
+
cat__age_3,0.0014320177855646557
|
| 31 |
+
cat__country_10,0.0014213292787885526
|
| 32 |
+
cat__age_5,0.0012658681149728401
|
| 33 |
+
cat__age_4,0.0011676351945606124
|
| 34 |
+
cat__country_8,0.0009806737230181408
|
| 35 |
+
cat__planned_activity_6.0,0.0009291243639821009
|
| 36 |
+
cat__planned_activity_1.0,0.0009189380963146095
|
| 37 |
+
cat__country_4,0.0009131123524501497
|
| 38 |
+
cat__country_3,0.0008415104761365984
|
| 39 |
+
cat__country_9,0.0006391833106764304
|
| 40 |
+
cat__country_6,0.0005992862929722725
|
| 41 |
+
cat__country_15,0.0005961254268885925
|
| 42 |
+
cat__planned_activity_7.0,0.0005203510135881252
|
| 43 |
+
cat__age_1,0.0004786493334730666
|
| 44 |
+
cat__country_7,0.00045682833180356916
|
| 45 |
+
cat__age_6,0.0004286939505569954
|
| 46 |
+
cat__country_11,0.0004238714395302936
|
| 47 |
+
cat__country_13,0.0004145802913805203
|
| 48 |
+
cat__country_14,0.0003922145285860367
|
| 49 |
+
cat__country_99,0.0003900235538848013
|
| 50 |
+
cat__country_19,0.00036882244028804445
|
| 51 |
+
cat__country_12,0.00036315525951813906
|
| 52 |
+
cat__country_18,0.00036298882952069363
|
| 53 |
+
cat__planned_activity_8.0,0.0003433572001072147
|
| 54 |
+
cat__country_16,0.0003025359192961062
|
| 55 |
+
cat__planned_activity_5.0,0.00025347994754935764
|
| 56 |
+
cat__country_17,0.0002398183809611425
|
| 57 |
+
cat__country_20,6.0512142783746834e-05
|
models/imputation_base_data.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/kmeans.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:71afae6ee2376a18af9bda0f6b7a3e8263458cf7e8747ed1a49b89f5e3834ff9
|
| 3 |
+
size 78363
|
models/pca.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:78eae182e98318f55b029c52086e88bfce0dad6fa28a6cd426d8b140925a4a39
|
| 3 |
+
size 2815
|
models/preprocessor.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:29424baa251ab960783524bdd697cbf145a5db0b1a6eb3313944a5acf99fd766
|
| 3 |
+
size 7330
|
models/variable_weights.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"revisit_indicator":0.4589085007,
|
| 3 |
+
"visit_local_indicator":0.3976571565,
|
| 4 |
+
"planned_activity":0.0368824736,
|
| 5 |
+
"stay_duration":0.0226750987,
|
| 6 |
+
"food_percent":0.0223403219,
|
| 7 |
+
"shopping_percent":0.0141709453,
|
| 8 |
+
"accommodation_percent":0.0140189851,
|
| 9 |
+
"language":0.0044699972,
|
| 10 |
+
"budget":0.0041996554,
|
| 11 |
+
"safety":0.0039268631,
|
| 12 |
+
"navigation":0.0038408829,
|
| 13 |
+
"transport":0.0036206833,
|
| 14 |
+
"food":0.0035211253,
|
| 15 |
+
"accommodation":0.0034456459,
|
| 16 |
+
"landscape":0.0032784775,
|
| 17 |
+
"heritage":0.0030431879
|
| 18 |
+
}
|
rag_retriever.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# rag_retriever.py
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import faiss
|
| 6 |
+
import numpy as np
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
from openai import OpenAI
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
+
|
| 11 |
+
# --- ์ค์ ---
|
| 12 |
+
MODEL_NAME = 'jhgan/ko-sbert-nli'
|
| 13 |
+
LLM_MODEL_NAME = 'gpt-3.5-turbo'
|
| 14 |
+
DATA_REPO = "Syngyeon/seoulalpha-data"
|
| 15 |
+
TOP_K = 10
|
| 16 |
+
|
| 17 |
+
# OpenAI ํด๋ผ์ด์ธํธ ์ด๊ธฐํ
|
| 18 |
+
client = OpenAI(api_key=os.getenv("API_KEY"))
|
| 19 |
+
|
| 20 |
+
# --- ๋ฆฌ์์ค ๋ก๋ฉ ---
|
| 21 |
+
def _load_resources():
|
| 22 |
+
"""๋ชจ๋ ๋ก๋ฉ ์ ๊ฒ์์ ํ์ํ ๋ฆฌ์์ค๋ฅผ ๋ฏธ๋ฆฌ ๋ถ๋ฌ์ต๋๋ค."""
|
| 23 |
+
try:
|
| 24 |
+
print("1. Hugging Face Hub์์ RAG ๋ฆฌ์์ค๋ฅผ ๋ค์ด๋ก๋ํฉ๋๋ค...")
|
| 25 |
+
|
| 26 |
+
# HF repo์์ ํ์ผ ๋ค์ด๋ก๋
|
| 27 |
+
index_path = hf_hub_download(repo_id=DATA_REPO, repo_type="dataset", filename="data/faiss/faiss_merged_output/merged.index")
|
| 28 |
+
metadata_path = hf_hub_download(repo_id=DATA_REPO, repo_type="dataset", filename="data/faiss/faiss_merged_output/merged_metadata.jsonl")
|
| 29 |
+
|
| 30 |
+
# ์๋ฒ ๋ฉ ๋ชจ๋ธ ๋ก๋
|
| 31 |
+
model = SentenceTransformer(MODEL_NAME)
|
| 32 |
+
|
| 33 |
+
# FAISS index ๋ก๋
|
| 34 |
+
index = faiss.read_index(index_path)
|
| 35 |
+
|
| 36 |
+
# ๋ฉํ๋ฐ์ดํฐ ๋ก๋
|
| 37 |
+
metadata_map = {}
|
| 38 |
+
with open(metadata_path, 'r', encoding='utf-8') as f:
|
| 39 |
+
for line in f:
|
| 40 |
+
meta = json.loads(line)
|
| 41 |
+
metadata_map[meta['vector_id']] = meta
|
| 42 |
+
|
| 43 |
+
print("RAG ๋ฆฌ์์ค ๋ก๋ฉ ์๋ฃ!")
|
| 44 |
+
return model, index, metadata_map
|
| 45 |
+
except Exception as e:
|
| 46 |
+
print(f"RAG ๋ฆฌ์์ค ๋ก๋ฉ์ ์คํจํ์ต๋๋ค: {e}")
|
| 47 |
+
return None, None, None
|
| 48 |
+
|
| 49 |
+
# ๋ชจ๋์ด ์ํฌํธ๋ ๋ ๋ฆฌ์์ค๋ฅผ ํ ๋ฒ๋ง ๋ก๋ํฉ๋๋ค.
|
| 50 |
+
embedding_model, faiss_index, meta_map = _load_resources()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _retrieve_places(query, k):
|
| 54 |
+
"""๋ด๋ถ ํจ์: ์ฟผ๋ฆฌ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ ์ฌํ ์ฅ์๋ฅผ ๊ฒ์ํฉ๋๋ค."""
|
| 55 |
+
query_vector = embedding_model.encode([query])
|
| 56 |
+
distances, ids = faiss_index.search(query_vector.astype('float32'), k)
|
| 57 |
+
|
| 58 |
+
results = []
|
| 59 |
+
for vector_id in ids[0]:
|
| 60 |
+
if vector_id in meta_map:
|
| 61 |
+
results.append(meta_map[vector_id])
|
| 62 |
+
return results
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _generate_answer_with_llm(query, retrieved_places):
|
| 66 |
+
"""๋ด๋ถ ํจ์: ๊ฒ์๋ ์ ๋ณด๋ฅผ ๋ฐํ์ผ๋ก LLM ๋ต๋ณ์ ์์ฑํฉ๋๋ค."""
|
| 67 |
+
context = ""
|
| 68 |
+
for i, place in enumerate(retrieved_places[:5]): # ์์ 5๊ฐ ์ ๋ณด๋ง ์ฌ์ฉ
|
| 69 |
+
context += f"--- ์ฅ์ ์ ๋ณด {i+1} ---\n"
|
| 70 |
+
context += f"์ด๋ฆ: {place.get('name', '์ ๋ณด ์์')}\n"
|
| 71 |
+
context += f"์ฃผ์: {place.get('address', '์ ๋ณด ์์')}\n"
|
| 72 |
+
context += f"AI ์์ฝ: {place.get('ai_summary', '์ ๋ณด ์์')}\n"
|
| 73 |
+
processed_sentences = place.get('processed_sentences', [])
|
| 74 |
+
context += "์ฃผ์ ํน์ง ๋ฐ ํ๊ธฐ:\n"
|
| 75 |
+
for sentence in processed_sentences:
|
| 76 |
+
context += f"- {sentence}\n"
|
| 77 |
+
context += "\n"
|
| 78 |
+
|
| 79 |
+
system_prompt = "๋น์ ์ ์ฌ์ฉ์์ ์ง๋ฌธ์ ๊ฐ์ฅ ์ ํฉํ ์ฅ์๋ฅผ ์ถ์ฒํด์ฃผ๋ ์ ์ฉํ ์ด์์คํดํธ์
๋๋ค."
|
| 80 |
+
user_prompt = f"""
|
| 81 |
+
์๋ '์ฅ์ ์ ๋ณด'๋ง์ ๋ฐํ์ผ๋ก ์ฌ์ฉ์์ ์ง๋ฌธ์ ๋ํ ๋ต๋ณ์ ์์ฑํด ์ฃผ์ธ์.
|
| 82 |
+
|
| 83 |
+
[์ง์์ฌํญ]
|
| 84 |
+
1. ๊ฒ์๋ ์ฅ์ ์ค์์ ์ง๋ฌธ๊ณผ ๊ฐ์ฅ ๊ด๋ จ์ฑ์ด ๋์ 2~3๊ณณ์ ์ถ์ฒํด ์ฃผ์ธ์.
|
| 85 |
+
2. ๊ฐ ์ฅ์๋ฅผ ์ถ์ฒํ ๋, ๋ฐ๋์ '์ด๋ฆ'๊ณผ '์ฃผ์'๋ฅผ ๋ช
ํํ๊ฒ ํจ๊ป ํ์ํด์ฃผ์ธ์.
|
| 86 |
+
3. ๊ฐ ์ฅ์๋ฅผ ์ถ์ฒํ๋ ์ด์ ๋ฅผ 'AI ์์ฝ'๊ณผ '์ฃผ์ ํน์ง ๋ฐ ํ๊ธฐ'๋ฅผ ๊ทผ๊ฑฐ๋ก ๊ตฌ์ฒด์ ์ผ๋ก ์ค๋ช
ํด ์ฃผ์ธ์.
|
| 87 |
+
4. 'processed_sentences'์ ์๋ ์ค์ ํ๊ธฐ๋ฅผ ์ธ์ฉํ์ฌ ๋ต๋ณํ๋ฉด ์ ๋ขฐ๋๋ฅผ ๋์ผ ์ ์์ต๋๋ค.
|
| 88 |
+
5. ์น์ ํ๊ณ ์์ฐ์ค๋ฌ์ด ๋งํฌ๋ก ๋ต๋ณํด ์ฃผ์ธ์.
|
| 89 |
+
|
| 90 |
+
--- ์ฅ์ ์ ๋ณด ---
|
| 91 |
+
{context}
|
| 92 |
+
--- ์ฌ์ฉ์์ ์ง๋ฌธ ---
|
| 93 |
+
{query}
|
| 94 |
+
"""
|
| 95 |
+
try:
|
| 96 |
+
response = client.chat.completions.create(
|
| 97 |
+
model=LLM_MODEL_NAME,
|
| 98 |
+
messages=[
|
| 99 |
+
{"role": "system", "content": system_prompt},
|
| 100 |
+
{"role": "user", "content": user_prompt}
|
| 101 |
+
],
|
| 102 |
+
temperature=0.7,
|
| 103 |
+
)
|
| 104 |
+
return response.choices[0].message.content
|
| 105 |
+
except Exception as e:
|
| 106 |
+
return f"LLM ๋ต๋ณ ์์ฑ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {e}"
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# --- ๋ํ ์คํ ํจ์ ---
|
| 110 |
+
def get_rag_recommendation(search_query, region_keywords):
|
| 111 |
+
"""
|
| 112 |
+
๊ฒ์ ์ฟผ๋ฆฌ์ ์ง์ญ ํค์๋๋ฅผ ๋ฐ์ RAG ์์คํ
์ ํตํด ์ต์ข
์ถ์ฒ ๋ต๋ณ์ ๋ฐํํฉ๋๋ค.
|
| 113 |
+
"""
|
| 114 |
+
if not all([embedding_model, faiss_index, meta_map]):
|
| 115 |
+
return "RAG ์์คํ
์ด ์ค๋น๋์ง ์์ ์ถ์ฒ์ ์์ฑํ ์ ์์ต๋๋ค."
|
| 116 |
+
|
| 117 |
+
# 1. ์ฅ์ ๊ฒ์
|
| 118 |
+
print("\n[RAG] ์๋ฏธ์ ์ผ๋ก ์ ์ฌํ ์ฅ์๋ฅผ ๊ฒ์ํฉ๋๋ค...")
|
| 119 |
+
top_places = _retrieve_places(search_query, k=100)
|
| 120 |
+
|
| 121 |
+
if not top_places:
|
| 122 |
+
return "๊ด๋ จ๋ ์ฅ์๋ฅผ ์ฐพ์ง ๋ชปํ์ต๋๋ค."
|
| 123 |
+
|
| 124 |
+
# 2. ์ง์ญ ํํฐ๋ง
|
| 125 |
+
if region_keywords:
|
| 126 |
+
print(f"[RAG] ์ฃผ์ ํํฐ๋ง (ํค์๋: {region_keywords})...")
|
| 127 |
+
filtered_places = []
|
| 128 |
+
for place in top_places:
|
| 129 |
+
address = place.get('address', '')
|
| 130 |
+
if any(keyword in address for keyword in region_keywords):
|
| 131 |
+
filtered_places.append(place)
|
| 132 |
+
if len(filtered_places) >= 10:
|
| 133 |
+
break
|
| 134 |
+
print(f"[RAG] ํํฐ๋ง ํ ๋จ์ ์ฅ์: {[p.get('name') for p in filtered_places]}")
|
| 135 |
+
else:
|
| 136 |
+
print("[RAG] ์ง์ญ ํค์๋๊ฐ ์์ด ํํฐ๋ง์ ๊ฑด๋๋๋๋ค.")
|
| 137 |
+
filtered_places = top_places
|
| 138 |
+
|
| 139 |
+
if not filtered_places:
|
| 140 |
+
return "์์ฒญํ์ ์ง์ญ์ ๋ง๋ ์ฅ์๋ฅผ ์ฐพ์ง ๋ชปํ์ต๋๋ค."
|
| 141 |
+
|
| 142 |
+
# 3. LLM์ผ๋ก ๋ต๋ณ ์์ฑ
|
| 143 |
+
print("[RAG] ํํฐ๋ง๋ ์ ๋ณด๋ฅผ ๋ฐํ์ผ๋ก ์ต์ข
๋ต๋ณ์ ์์ฑํฉ๋๋ค...")
|
| 144 |
+
final_answer = _generate_answer_with_llm(search_query, filtered_places)
|
| 145 |
+
|
| 146 |
+
return final_answer
|
region_extractor.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import faiss
|
| 4 |
+
import numpy as np
|
| 5 |
+
from sentence_transformers import SentenceTransformer
|
| 6 |
+
from huggingface_hub import hf_hub_download
|
| 7 |
+
from openai import OpenAI # ๐น ์ถ๊ฐ
|
| 8 |
+
|
| 9 |
+
DATA_REPO = "Syngyeon/seoulalpha-data"
|
| 10 |
+
MODEL_NAME = "jhgan/ko-sbert-nli"
|
| 11 |
+
|
| 12 |
+
# OpenAI ํด๋ผ์ด์ธํธ ์ด๊ธฐํ
|
| 13 |
+
client = OpenAI(api_key=os.getenv("API_KEY")) # ๐น ์ถ๊ฐ
|
| 14 |
+
|
| 15 |
+
# ๋ก๋
|
| 16 |
+
def _load_region_index():
|
| 17 |
+
try:
|
| 18 |
+
index_path = hf_hub_download(
|
| 19 |
+
repo_id=DATA_REPO, repo_type="dataset",
|
| 20 |
+
filename="data/faiss/region_db/faiss_region_semantic.index"
|
| 21 |
+
)
|
| 22 |
+
metadata_path = hf_hub_download(
|
| 23 |
+
repo_id=DATA_REPO, repo_type="dataset",
|
| 24 |
+
filename="data/faiss/region_db/metadata_region_semantic.jsonl"
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
index = faiss.read_index(index_path)
|
| 28 |
+
model = SentenceTransformer(MODEL_NAME)
|
| 29 |
+
|
| 30 |
+
metadata_map = {}
|
| 31 |
+
with open(metadata_path, "r", encoding="utf-8") as f:
|
| 32 |
+
for line in f:
|
| 33 |
+
meta = json.loads(line)
|
| 34 |
+
metadata_map[meta["vector_id"]] = meta
|
| 35 |
+
|
| 36 |
+
print("[RegionDB] ๋ก๋ฉ ์๋ฃ")
|
| 37 |
+
return model, index, metadata_map
|
| 38 |
+
except Exception as e:
|
| 39 |
+
print("[RegionDB] ๋ก๋ฉ ์คํจ:", e)
|
| 40 |
+
return None, None, None
|
| 41 |
+
|
| 42 |
+
region_model, region_index, region_meta = _load_region_index()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def extract_region_semantic(user_query, top_k=5):
|
| 46 |
+
"""FAISS ๊ธฐ๋ฐ ์ง์ญ ํ๋ณด ์ถ์ถ"""
|
| 47 |
+
if not all([region_model, region_index, region_meta]):
|
| 48 |
+
return []
|
| 49 |
+
|
| 50 |
+
query_vec = region_model.encode([user_query]).astype("float32")
|
| 51 |
+
distances, ids = region_index.search(query_vec, top_k)
|
| 52 |
+
|
| 53 |
+
results = []
|
| 54 |
+
for i, vid in enumerate(ids[0]):
|
| 55 |
+
if vid in region_meta:
|
| 56 |
+
results.append(region_meta[vid]["region_name"])
|
| 57 |
+
return results
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def extract_region_from_query(user_query):
|
| 61 |
+
"""
|
| 62 |
+
์ฌ์ฉ์ ์ง๋ฌธ์์ LLM์ ์ฌ์ฉํด ์ง์ญ๋ช
ํค์๋ ๋ฆฌ์คํธ๋ฅผ ์ถ์ถํฉ๋๋ค.
|
| 63 |
+
"""
|
| 64 |
+
print("[LLM] ์ฌ์ฉ์ ์ฟผ๋ฆฌ์์ ์ง์ญ๋ช
ํค์๋๋ฅผ ์ถ์ถํฉ๋๋ค...")
|
| 65 |
+
|
| 66 |
+
system_prompt = """
|
| 67 |
+
๋น์ ์ ์ฌ์ฉ์์ ์ฌํ ๊ด๋ จ ์ง๋ฌธ์์ '๋ํ๋ฏผ๊ตญ ํ์ ๊ตฌ์ญ' ํค์๋๋ฅผ ์ถ์ถํ๋ AI ์ด์์คํดํธ์
๋๋ค.
|
| 68 |
+
์ฌ์ฉ์์ ์ง๋ฌธ์ ๋ถ์ํ์ฌ, ์ฃผ์ ํํฐ๋ง์ ์ฌ์ฉํ ์ ์๋ ํค์๋ ๋ชฉ๋ก์ JSON ํ์์ผ๋ก ๋ฐํํด ์ฃผ์ธ์.
|
| 69 |
+
๊ฒฐ๊ณผ๋ ๋ฐ๋์ {"regions": ["ํค์๋1", "ํค์๋2", ...]} ํํ์ฌ์ผ ํฉ๋๋ค.
|
| 70 |
+
|
| 71 |
+
- "์ ๋ผ๋"๋ "์ ๋ถ", "์ ๋จ", "๊ด์ฃผ"๋ก ํด์ํฉ๋๋ค.
|
| 72 |
+
- "๊ฒฝ์๋"๋ "๊ฒฝ๋ถ", "๊ฒฝ๋จ", "๋ถ์ฐ", "๋๊ตฌ", "์ธ์ฐ"์ผ๋ก ํด์ํฉ๋๋ค.
|
| 73 |
+
- "์ถฉ์ฒญ๋"๋ "์ถฉ๋ถ", "์ถฉ๋จ", "๋์ ", "์ธ์ข
"์ผ๋ก ํด์ํฉ๋๋ค.
|
| 74 |
+
- "์์ธ ๊ทผ๊ต"๋ "๊ฒฝ๊ธฐ", "์ธ์ฒ"์ผ๋ก ํด์ํฉ๋๋ค.
|
| 75 |
+
- ์ธ๊ธ๋ ์ง์ญ์ด ์์ผ๋ฉด ๋น ๋ฆฌ์คํธ []๋ฅผ ๋ฐํํฉ๋๋ค.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
messages = [
|
| 79 |
+
{"role": "system", "content": system_prompt},
|
| 80 |
+
{"role": "user", "content": user_query}
|
| 81 |
+
]
|
| 82 |
+
|
| 83 |
+
try:
|
| 84 |
+
response = client.chat.completions.create(
|
| 85 |
+
model="gpt-3.5-turbo",
|
| 86 |
+
messages=messages,
|
| 87 |
+
response_format={"type": "json_object"}
|
| 88 |
+
)
|
| 89 |
+
result = json.loads(response.choices[0].message.content)
|
| 90 |
+
|
| 91 |
+
if 'regions' in result and isinstance(result['regions'], list):
|
| 92 |
+
return result['regions']
|
| 93 |
+
else:
|
| 94 |
+
return []
|
| 95 |
+
except Exception as e:
|
| 96 |
+
print(f"[LLM] ์ง์ญ๋ช
์ถ์ถ ์ค ์ค๋ฅ ๋ฐ์: {e}")
|
| 97 |
+
return []
|
requirements.txt
ADDED
|
Binary file (322 Bytes). View file
|
|
|