Spaces:
Sleeping
Sleeping
| import os | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| import umap.umap_ as umap | |
| # ========================================== | |
| # [Script description] | |
| # Advanced vector DB visualization tool (3D & interactive) | |
| # 1. Loads every product vector stored in ChromaDB. | |
| # 2. Reduces the 768-dimensional vectors to 3 dimensions with the UMAP algorithm. | |
| # 3. Uses Plotly to generate a rotatable/zoomable 3D scatter-plot HTML file. | |
| # 4. (Optional) When a question is supplied, its query vector is plotted as well. | |
| # ========================================== | |
# --- Settings ---
# Every location below is derived from the directory containing this script,
# climbing two directory levels before descending into the target folder.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
_UP_TWO_LEVELS = ('..', '..')

# Persisted Chroma vector store.
CHROMA_DB_PATH = os.path.join(BASE_DIR, *_UP_TWO_LEVELS, 'data', 'chroma_db')
# Local directory of the embedding model used to encode queries.
EMBEDDING_MODEL_PATH = os.path.join(
    BASE_DIR, *_UP_TWO_LEVELS, 'models', 'snowflake-finetuned-hard'
)
# Destination of the generated interactive HTML figure.
OUTPUT_HTML_PATH = os.path.join(
    BASE_DIR, *_UP_TWO_LEVELS, 'embedding_visualization_3d.html'
)
def _build_hover_rows(metadatas, documents):
    """Turn parallel metadata/document lists into plain dict rows for a DataFrame.

    A ``None`` metadata entry is treated as an empty dict and a ``None``
    document as an empty string, so one incomplete record falls back to the
    defaults instead of aborting the whole visualization.
    """
    rows = []
    for meta, doc in zip(metadatas, documents):
        meta = meta or {}
        doc = doc or ''
        rows.append({
            'product_name': meta.get('product_name', 'Unknown'),
            'category': meta.get('category', 'Etc'),
            'brand': meta.get('brand', ''),
            'price': meta.get('price', 0),
            'text_preview': doc[:100] + "...",  # short preview shown on hover
        })
    return rows


def visualize_3d(query_text=None, device='cuda'):
    """Project all stored embeddings to 3D with UMAP and save an interactive plot.

    Loads the embedding model and the persisted Chroma store, fetches every
    embedding together with its metadata and document text, reduces the
    vectors to three components with UMAP, and writes a Plotly 3D scatter
    plot (coloured by category) to ``OUTPUT_HTML_PATH``.

    Args:
        query_text: Optional question. When truthy it is embedded with the same
            model, projected together with the stored vectors, and drawn as an
            extra red diamond marker.
        device: Device string passed to the embedding model through
            ``model_kwargs``. Defaults to ``'cuda'`` (the value that used to be
            hard-coded); pass ``'cpu'`` on machines without a GPU.

    Returns:
        None. Returns early, without writing a file, when the store holds no
        embeddings.
    """
    print("--- 3D ์๋ฒ ๋ฉ ์๊ฐํ ์์ ---")

    # 1. Load the embedding model and open the vector store.
    print(f"๋ชจ๋ธ ๋ก๋ ์ค: {EMBEDDING_MODEL_PATH}")
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_PATH,
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
    )
    vectorstore = Chroma(
        persist_directory=CHROMA_DB_PATH,
        embedding_function=embeddings,
    )

    # 2. Extract every record from the store.
    print("DB์์ ๋ฐ์ดํฐ ์ถ์ถ ์ค...")
    data = vectorstore.get(include=['embeddings', 'metadatas', 'documents'])
    # Explicit None/length checks work for both lists and NumPy arrays
    # (truth-testing a multi-element array would raise).
    if data['embeddings'] is None or len(data['embeddings']) == 0:
        print("๋ฐ์ดํฐ๊ฐ ์์ต๋๋ค.")
        return

    vectors = np.array(data['embeddings'])
    # Tidy the metadata into rows for the DataFrame.
    df_data = _build_hover_rows(data['metadatas'], data['documents'])

    # 3. (Optional) append the query vector as the last row so it is projected
    #    into the same space as the stored vectors.
    has_query = bool(query_text)
    if has_query:
        print(f"์ง๋ฌธ ๋ฒกํฐ ์์ฑ ์ค: '{query_text}'")
        query_vector = embeddings.embed_query(query_text)
        vectors = np.vstack([vectors, np.array(query_vector)])
        df_data.append({
            'product_name': f"โ ์ง๋ฌธ: {query_text}",
            'category': 'Query',
            'brand': '-',
            'price': 0,
            'text_preview': query_text,
        })
        print("์ง๋ฌธ ๋ฒกํฐ๊ฐ ๋ฐ์ดํฐ์ ์ถ๊ฐ๋์์ต๋๋ค.")

    # 4. Dimensionality reduction (UMAP, 3 components).
    print(f"์ฐจ์ ์ถ์ ์ค (768 -> 3D)... ๋ฐ์ดํฐ ๊ฐ์: {len(vectors)}")
    reducer = umap.UMAP(n_components=3, n_neighbors=15, metric='cosine', random_state=42)
    projections = reducer.fit_transform(vectors)

    # Assemble the DataFrame: one row per vector plus its 3D coordinates.
    df = pd.DataFrame(df_data)
    df['x'] = projections[:, 0]
    df['y'] = projections[:, 1]
    df['z'] = projections[:, 2]

    # 5. Visualization (Plotly 3D).
    print("3D ๊ทธ๋ํ ์์ฑ ์ค...")
    # Base scatter plot; text_preview is included so the preview column that is
    # built for hovering is actually displayed.
    fig = px.scatter_3d(
        df, x='x', y='y', z='z',
        color='category',
        hover_data=['product_name', 'brand', 'price', 'text_preview'],
        title='Nyang Chatbot Embedding Space (3D)',
        opacity=0.6,
    )
    # Shrink all markers; done before the query trace is added so the
    # highlight marker keeps its larger size.
    fig.update_traces(marker=dict(size=3))

    # Emphasize the query point. It is located by position (the row appended
    # above) rather than by category name, so a stored product whose category
    # happens to be 'Query' is not highlighted by mistake.
    if has_query:
        query_row = df.iloc[[-1]]
        fig.add_trace(go.Scatter3d(
            x=query_row['x'],
            y=query_row['y'],
            z=query_row['z'],
            mode='markers',
            marker=dict(
                size=10,
                color='red',
                symbol='diamond',
                line=dict(width=2, color='white'),
            ),
            name='Current Query',
            hoverinfo='text',
            text=f"ํ์ฌ ์ง๋ฌธ: {query_text}",
        ))

    # Layout polish: tight margins, grids on, zero lines off.
    fig.update_layout(
        margin=dict(l=0, r=0, b=0, t=40),
        scene=dict(
            xaxis=dict(showgrid=True, zeroline=False),
            yaxis=dict(showgrid=True, zeroline=False),
            zaxis=dict(showgrid=True, zeroline=False),
        ),
    )

    # 6. Save the figure as a standalone HTML file.
    fig.write_html(OUTPUT_HTML_PATH)
    print(f"์๊ฐํ ํ์ผ ์ ์ฅ ์๋ฃ: {OUTPUT_HTML_PATH}")
    print("์น ๋ธ๋ผ์ฐ์ ๋ก ํด๋น ํ์ผ์ ์ด์ด๋ณด์ธ์!")
if __name__ == "__main__":
    # Supplying a sample question lets you see where the query lands among
    # the stored vectors.
    sample_question = "๊ณ ์์ด ํธ ๊ด๋ฆฌํ๋ ๋น ์ถ์ฒํด์ค"
    visualize_3d(query_text=sample_question)