# Page residue kept for provenance: uploaded by leedami, "Deploy from Team Script", commit 41cc6f7 (verified).
import os
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
import umap.umap_ as umap
# ==========================================
# [Script description]
# Advanced vector DB visualization tool (3D & interactive)
# 1. Loads every product vector stored in ChromaDB.
# 2. Reduces the 768-dimensional vectors to 3 dimensions with the UMAP algorithm.
# 3. Uses Plotly to generate an HTML 3D scatter plot that can be rotated and zoomed.
# 4. (Optional) When a question is supplied, the question vector's position is plotted as well.
# ==========================================
# --- Configuration ---
# Every path below is anchored at the directory containing this script and
# then climbs two levels up (presumably the project root -- confirm against
# the repository layout).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
_TWO_LEVELS_UP = os.path.join(BASE_DIR, '..', '..')

# Directory of the persisted Chroma vector store.
CHROMA_DB_PATH = os.path.join(_TWO_LEVELS_UP, 'data', 'chroma_db')
# Directory of the embedding model handed to HuggingFaceEmbeddings.
EMBEDDING_MODEL_PATH = os.path.join(_TWO_LEVELS_UP, 'models', 'snowflake-finetuned-hard')
# Destination of the generated interactive HTML figure.
OUTPUT_HTML_PATH = os.path.join(_TWO_LEVELS_UP, 'embedding_visualization_3d.html')
def visualize_3d(query_text=None):
print("--- 3D ์ž„๋ฒ ๋”ฉ ์‹œ๊ฐํ™” ์‹œ์ž‘ ---")
# 1. ๋ชจ๋ธ ๋ฐ DB ๋กœ๋“œ
print(f"๋ชจ๋ธ ๋กœ๋“œ ์ค‘: {EMBEDDING_MODEL_PATH}")
embeddings = HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_PATH,
model_kwargs={'device': 'cuda'},
encode_kwargs={'normalize_embeddings': True}
)
vectorstore = Chroma(
persist_directory=CHROMA_DB_PATH,
embedding_function=embeddings
)
# 2. ๋ฐ์ดํ„ฐ ์ถ”์ถœ
print("DB์—์„œ ๋ฐ์ดํ„ฐ ์ถ”์ถœ ์ค‘...")
data = vectorstore.get(include=['embeddings', 'metadatas', 'documents'])
if data['embeddings'] is None or len(data['embeddings']) == 0:
print("๋ฐ์ดํ„ฐ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")
return
vectors = np.array(data['embeddings'])
metadatas = data['metadatas']
documents = data['documents']
# ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ ์ •๋ฆฌ (DataFrame ์ƒ์„ฑ์šฉ)
df_data = []
for i, meta in enumerate(metadatas):
df_data.append({
'product_name': meta.get('product_name', 'Unknown'),
'category': meta.get('category', 'Etc'),
'brand': meta.get('brand', ''),
'price': meta.get('price', 0),
'text_preview': documents[i][:100] + "..." # ํˆดํŒ์šฉ ๋ฏธ๋ฆฌ๋ณด๊ธฐ
})
# 3. (์˜ต์…˜) ์งˆ๋ฌธ ๋ฒกํ„ฐ ์ถ”๊ฐ€
if query_text:
print(f"์งˆ๋ฌธ ๋ฒกํ„ฐ ์ƒ์„ฑ ์ค‘: '{query_text}'")
query_vector = embeddings.embed_query(query_text)
vectors = np.vstack([vectors, np.array(query_vector)])
df_data.append({
'product_name': f"โ“ ์งˆ๋ฌธ: {query_text}",
'category': 'Query',
'brand': '-',
'price': 0,
'text_preview': query_text
})
print("์งˆ๋ฌธ ๋ฒกํ„ฐ๊ฐ€ ๋ฐ์ดํ„ฐ์— ์ถ”๊ฐ€๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
# 4. ์ฐจ์› ์ถ•์†Œ (UMAP 3D)
print(f"์ฐจ์› ์ถ•์†Œ ์ค‘ (768 -> 3D)... ๋ฐ์ดํ„ฐ ๊ฐœ์ˆ˜: {len(vectors)}")
reducer = umap.UMAP(n_components=3, n_neighbors=15, metric='cosine', random_state=42)
projections = reducer.fit_transform(vectors)
# DataFrame ์ƒ์„ฑ
df = pd.DataFrame(df_data)
df['x'] = projections[:, 0]
df['y'] = projections[:, 1]
df['z'] = projections[:, 2]
# 5. ์‹œ๊ฐํ™” (Plotly 3D)
print("3D ๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ ์ค‘...")
# ๊ธฐ๋ณธ ์‚ฐ์ ๋„ ์ƒ์„ฑ
fig = px.scatter_3d(
df, x='x', y='y', z='z',
color='category',
hover_data=['product_name', 'brand', 'price'],
title='Nyang Chatbot Embedding Space (3D)',
opacity=0.6
)
# ์  ํฌ๊ธฐ ์กฐ์ ˆ (์ „์ฒด์ ์œผ๋กœ ์ž‘๊ฒŒ)
fig.update_traces(marker=dict(size=3))
# ์งˆ๋ฌธ(Query) ์ ์ด ์žˆ๋‹ค๋ฉด ๋ณ„๋„๋กœ ๊ฐ•์กฐ
if query_text:
query_idx = df[df['category'] == 'Query'].index
if not query_idx.empty:
fig.add_trace(go.Scatter3d(
x=df.loc[query_idx, 'x'],
y=df.loc[query_idx, 'y'],
z=df.loc[query_idx, 'z'],
mode='markers',
marker=dict(
size=10,
color='red',
symbol='diamond',
line=dict(width=2, color='white')
),
name='Current Query',
hoverinfo='text',
text=f"ํ˜„์žฌ ์งˆ๋ฌธ: {query_text}"
))
# ์Šคํƒ€์ผ ๊ฐœ์„ 
fig.update_layout(
margin=dict(l=0, r=0, b=0, t=40),
scene=dict(
xaxis=dict(showgrid=True, zeroline=False),
yaxis=dict(showgrid=True, zeroline=False),
zaxis=dict(showgrid=True, zeroline=False)
)
)
# 6. ์ €์žฅ
fig.write_html(OUTPUT_HTML_PATH)
print(f"์‹œ๊ฐํ™” ํŒŒ์ผ ์ €์žฅ ์™„๋ฃŒ: {OUTPUT_HTML_PATH}")
print("์›น ๋ธŒ๋ผ์šฐ์ €๋กœ ํ•ด๋‹น ํŒŒ์ผ์„ ์—ด์–ด๋ณด์„ธ์š”!")
if __name__ == "__main__":
    # Run with a sample question to see where the query lands among the
    # product embeddings. The literal was mis-encoded (mojibake); it is
    # restored here to the Korean sentence "recommend a comb for cat fur care".
    test_query = "고양이 털 관리하는 빗 추천해줘"
    visualize_3d(test_query)