Spaces:
Sleeping
Sleeping
| import json | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.manifold import TSNE | |
| import matplotlib.pyplot as plt # matplotlib์ ํฐํธ ์ค์ ๋ก์ง์ ํ์ | |
| import matplotlib.font_manager as fm | |
| import numpy as np | |
| import platform | |
| import os | |
| import networkx as nx # ๊ทธ๋ํ ๊ตฌ์กฐ ์์ฑ | |
| import plotly.graph_objects as go # 3D ์๊ฐํ | |
| from sklearn.metrics.pairwise import cosine_similarity # ์ ์ฌ๋ ๊ณ์ฐ | |
# --- Korean font setup ---
def set_korean_font():
    """
    Pick a Korean-capable font for the current OS and apply it to matplotlib.

    Windows and macOS use a fixed family name. On Linux the NanumGothic TTF at
    a fixed path is tried first, then any font known to matplotlib whose name
    contains 'Nanum', then the Windows/macOS families. When nothing is found,
    matplotlib's rc settings are reset to their defaults. In every case
    ``axes.unicode_minus`` is turned off.

    Returns:
        The font family name to hand to Plotly; 'sans-serif' when no Korean
        font was found.
    """
    fixed_fonts_by_os = {"Windows": "Malgun Gothic", "Darwin": "AppleGothic"}
    os_name = platform.system()
    mpl_family = None     # family applied to matplotlib
    plotly_family = None  # family name returned for Plotly

    if os_name in fixed_fonts_by_os:
        mpl_family = plotly_family = fixed_fonts_by_os[os_name]
    elif os_name == "Linux":
        nanum_ttf = "/usr/share/fonts/truetype/nanum/NanumGothic.ttf"
        if os.path.exists(nanum_ttf):
            # Plotly is given the plain family name, matplotlib the name read from the file.
            mpl_family = fm.FontProperties(fname=nanum_ttf).get_name()
            plotly_family = "NanumGothic"
            print(f"Using font: {mpl_family} from {nanum_ttf}")
        else:
            # Fall back to whatever matplotlib's font manager already knows about.
            try:
                installed = [entry.name for entry in fm.fontManager.ttflist]
                first_nanum = next((name for name in installed if 'Nanum' in name), None)
                if first_nanum is not None:
                    mpl_family = plotly_family = first_nanum
                    print(f"Found and using system font: {mpl_family}")
                else:
                    # Try the families normally shipped with other operating systems.
                    for candidate in ("Malgun Gothic", "AppleGothic"):
                        if candidate in installed:
                            mpl_family = plotly_family = candidate
                            print(f"Trying fallback font: {mpl_family}")
                            break
            except Exception as e:
                print(f"Error finding Linux font: {e}")
                mpl_family = None
        if not mpl_family:
            print("Warning: Linux ํ๊ธ ํฐํธ๋ฅผ ์๋์ผ๋ก ์ฐพ์ง ๋ชปํ์ต๋๋ค. Matplotlib ๊ธฐ๋ณธ ํฐํธ๋ฅผ ์ฌ์ฉํฉ๋๋ค.")
            mpl_family = None
            plotly_family = None

    # Apply the matplotlib setting, or reset to defaults when there is nothing to apply.
    if not mpl_family:
        print("Matplotlib Korean font not set. Using default font.")
        plt.rcdefaults()
        plt.rc('axes', unicode_minus=False)
    else:
        try:
            plt.rc('font', family=mpl_family)
            plt.rc('axes', unicode_minus=False)
            print(f"Matplotlib font set to: {mpl_family}")
        except Exception as e:
            print(f"Error setting Matplotlib font '{mpl_family}': {e}. Using default.")
            plt.rcdefaults()
            plt.rc('axes', unicode_minus=False)

    if not plotly_family:
        print("Plotly font name not explicitly found, will use Plotly default (sans-serif).")
        plotly_family = 'sans-serif'
    print(f"Plotly will try to use font: {plotly_family}")
    return plotly_family
# --- Data loading ---
def load_titles_from_json(filepath):
    """Load the non-empty 'word' values from a JSON file containing a list of objects.

    Entries that are not JSON objects (strings, numbers, null, nested lists)
    are skipped with a warning instead of aborting the whole load. Objects
    whose 'word' key is missing or falsy are dropped silently.

    Args:
        filepath: Path of the UTF-8 encoded JSON file to read.

    Returns:
        A list with the truthy 'word' values in file order (possibly empty), or
        None when the file is missing, unreadable, not valid JSON, or its
        top-level value is not a list. A message is printed in each None case.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"์ค๋ฅ: ํ์ผ '{filepath}'๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค.")
        return None
    except json.JSONDecodeError:
        print(f"์ค๋ฅ: ํ์ผ '{filepath}'์ JSON ํ์์ด ์๋ชป๋์์ต๋๋ค.")
        return None
    except Exception as e:
        # Kept broad on purpose: any other read/decode failure is reported and
        # turned into None, which is what callers already handle.
        print(f"๋ฐ์ดํฐ ๋ก๋ฉ ์ค ์ค๋ฅ ๋ฐ์: {e}")
        return None

    if not isinstance(data, list):
        print(f"์ค๋ฅ: ํ์ผ '{filepath}'์ ์ต์์ ํ์์ด ๋ฆฌ์คํธ๊ฐ ์๋๋๋ค.")
        return None

    # Only dict entries can carry a 'word' key; anything else is malformed input.
    records = [item for item in data if isinstance(item, dict)]
    skipped = len(data) - len(records)
    if skipped:
        print(f"Warning: skipped {skipped} non-object entries in '{filepath}'.")

    return [record['word'] for record in records if record.get('word')]
# --- Script entry point ---
# Pipeline: load words -> embed -> t-SNE to 3D -> connect similar words ->
# build a NetworkX graph -> render it as an interactive Plotly 3D figure.
if __name__ == "__main__":
    # Local import keeps this block self-contained. sys.exit(1) replaces the
    # interactive-only exit() helper so failures return a non-zero status.
    import sys

    # Configure the matplotlib font and get the family name to use in Plotly.
    plotly_font = set_korean_font()

    # --- Settings ---
    data_file_path = 'child_mind_data.json'  # input data file
    embedding_model_name = 'BAAI/bge-m3'     # SentenceTransformer model id
    similarity_threshold = 0.7               # an edge is created when cosine similarity exceeds this
    tsne_perplexity = 30                     # upper bound; clamped below the sample count later
    tsne_max_iter = 1000                     # t-SNE optimisation iterations
    # PCA initialisation with 3 components needs at least 3 samples, and
    # t-SNE requires perplexity < n_samples.
    min_words_for_tsne = 3

    # 1. Load the word list.
    print(f"๋ฐ์ดํฐ ๋ก๋ฉ ์๋: {data_file_path}")
    word_list = load_titles_from_json(data_file_path)
    if not word_list:
        print("์๊ฐํํ ์ดํ ๋ฐ์ดํฐ๊ฐ ์์ต๋๋ค. ํ๋ก๊ทธ๋จ์ ์ข ๋ฃํฉ๋๋ค.")
        sys.exit(1)
    print(f"์ด {len(word_list)}๊ฐ์ ์ ํจํ ์ดํ๋ฅผ ๋ก๋ํ์ต๋๋ค.")

    # Drop duplicates; sorting makes the node order deterministic.
    original_count = len(word_list)
    word_list = sorted(set(word_list))
    if len(word_list) < original_count:
        print(f"์ค๋ณต ์ ๊ฑฐ ํ {len(word_list)}๊ฐ์ ๊ณ ์ ํ ์ดํ๊ฐ ๋จ์์ต๋๋ค.")

    # Fail early (before loading the model) when the layout cannot be computed.
    if len(word_list) < min_words_for_tsne:
        print(
            f"Error: at least {min_words_for_tsne} unique words are required "
            f"for the 3D t-SNE layout, got {len(word_list)}."
        )
        sys.exit(1)

    # 2. Load the embedding model.
    print(f"์๋ฒ ๋ฉ ๋ชจ๋ธ ๋ก๋ฉ ์ค: {embedding_model_name} ...")
    try:
        model = SentenceTransformer(embedding_model_name)
    except Exception as e:
        print(f"์ค๋ฅ: ์๋ฒ ๋ฉ ๋ชจ๋ธ '{embedding_model_name}' ๋ก๋ฉ์ ์คํจํ์ต๋๋ค. {e}")
        print("์ธํฐ๋ท ์ฐ๊ฒฐ ๋ฐ ๋ชจ๋ธ ์ด๋ฆ์ ํ์ธํ์ธ์.")
        sys.exit(1)
    print("๋ชจ๋ธ ๋ก๋ฉ ์๋ฃ.")

    # 3. Embed every word (embeddings are L2-normalised by the encoder call).
    print("์ดํ ์๋ฒ ๋ฉ ์์ฑ ์ค...")
    try:
        embeddings = model.encode(word_list, show_progress_bar=True, normalize_embeddings=True)
    except Exception as e:
        print(f"์ค๋ฅ: ์๋ฒ ๋ฉ ์์ฑ ์ค ๋ฌธ์ ๊ฐ ๋ฐ์ํ์ต๋๋ค. {e}")
        sys.exit(1)
    print(f"์๋ฒ ๋ฉ ์์ฑ ์๋ฃ. ๊ฐ ์ดํ๋ {embeddings.shape[1]}์ฐจ์ ๋ฒกํฐ๋ก ๋ณํ๋์์ต๋๋ค.")

    # 4. Reduce the embeddings to 3D coordinates with t-SNE.
    print("3์ฐจ์ ์ขํ ์์ฑ ์ค (t-SNE)...")
    # Perplexity must stay below the number of samples; the minimum-size check
    # above guarantees this value is at least 2.
    effective_perplexity = min(tsne_perplexity, len(word_list) - 1)
    try:
        tsne = TSNE(
            n_components=3,
            random_state=42,
            perplexity=effective_perplexity,
            max_iter=tsne_max_iter,
            init='pca',
            learning_rate='auto',
        )
        embeddings_3d = tsne.fit_transform(embeddings)
    except Exception as e:
        print(f"์ค๋ฅ: t-SNE ์ฐจ์ ์ถ์ ์ค ๋ฌธ์ ๊ฐ ๋ฐ์ํ์ต๋๋ค. {e}")
        sys.exit(1)
    print("3์ฐจ์ ์ขํ ์์ฑ ์๋ฃ.")

    # 5. Pairwise similarity and edge selection.
    print("์ดํ ๊ฐ ์ ์ฌ๋ ๊ณ์ฐ ๋ฐ ์ฃ์ง ์ ์ ์ค...")
    try:
        similarity_matrix = cosine_similarity(embeddings)
    except Exception as e:
        print(f"์ค๋ฅ: ์ฝ์ฌ์ธ ์ ์ฌ๋ ๊ณ์ฐ ์ค ๋ฌธ์ ๊ฐ ๋ฐ์ํ์ต๋๋ค. {e}")
        sys.exit(1)

    # Keep only the strict upper triangle (i < j) so each pair is considered
    # once and self-loops are excluded. argwhere returns the pairs in row-major
    # order, the same order a nested i/j loop would produce.
    above_threshold = np.triu(similarity_matrix > similarity_threshold, k=1)
    edge_index_pairs = np.argwhere(above_threshold)
    edges = [(word_list[i], word_list[j]) for i, j in edge_index_pairs]
    edge_weights = [similarity_matrix[i, j] for i, j in edge_index_pairs]  # similarity doubles as edge weight
    print(f"์ ์ฌ๋ ์๊ณ๊ฐ ({similarity_threshold}) ์ด๊ณผ ์ฃ์ง {len(edges)}๊ฐ ์ ์ ์๋ฃ.")
    if not edges:
        # Not fatal: the figure is still drawn with nodes only.
        print("Warning: ์ ์๋ ์ฃ์ง๊ฐ ์์ต๋๋ค. ์ ์ฌ๋ ์๊ณ๊ฐ์ด ๋๋ฌด ๋๊ฑฐ๋ ๋ฐ์ดํฐ ๊ฐ ์ ์ฌ์ฑ์ด ๋ฎ์ ์ ์์ต๋๋ค.")

    # 6. Build the NetworkX graph; each node stores its 3D position.
    print("NetworkX ๊ทธ๋ํ ๊ฐ์ฒด ์์ฑ ์ค...")
    G = nx.Graph()
    for i, word in enumerate(word_list):
        G.add_node(word, pos=(embeddings_3d[i, 0], embeddings_3d[i, 1], embeddings_3d[i, 2]))
    for (source, target), weight in zip(edges, edge_weights):
        G.add_edge(source, target, weight=weight)
    print("NetworkX ๊ทธ๋ํ ์์ฑ ์๋ฃ.")

    # --- Plotly 3D visualisation ---
    print("Plotly 3D ๊ทธ๋ํ ์์ฑ ์ค...")

    # Edge coordinates: a None after each pair breaks the line between segments.
    # With no edges the lists stay empty and the trace simply draws nothing.
    edge_x = []
    edge_y = []
    edge_z = []
    for source, target in G.edges():
        x0, y0, z0 = G.nodes[source]['pos']
        x1, y1, z1 = G.nodes[target]['pos']
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])

    edge_trace = go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        mode='lines',
        line=dict(width=1, color='#888'),
        hoverinfo='none',  # edges carry no hover information
    )

    # Node coordinates, labels and hover text, all in graph node order.
    node_text = list(G.nodes())
    node_x = [G.nodes[node]['pos'][0] for node in node_text]
    node_y = [G.nodes[node]['pos'][1] for node in node_text]
    node_z = [G.nodes[node]['pos'][2] for node in node_text]

    node_adjacencies = []  # number of connections per node
    node_hover_text = []
    for word, neighbours in G.adjacency():
        num_connections = len(neighbours)
        node_adjacencies.append(num_connections)
        node_hover_text.append(f'{word}<br>Connections: {num_connections}')

    node_trace = go.Scatter3d(
        x=node_x, y=node_y, z=node_z,
        mode='markers+text',
        text=node_text,
        hovertext=node_hover_text,
        hoverinfo='text',
        textposition='top center',
        textfont=dict(
            size=10,
            color='black',
            family=plotly_font,
        ),
        marker=dict(
            size=6,
            color=node_z,  # colour encodes the node's z coordinate
            colorscale='Viridis',
            opacity=0.9,
            colorbar=dict(
                thickness=15,
                title=dict(text='Node Depth (Z-axis)', side='right'),
                xanchor='left',
            ),
        ),
    )

    # Shared styling for the three scene axes.
    axis_style = dict(
        showticklabels=False,
        backgroundcolor="rgb(230, 230,230)",
        gridcolor="white",
        zerolinecolor="white",
    )
    layout = go.Layout(
        title=dict(
            text=f'์ดํ ์๋ฏธ ์ ์ฌ์ฑ ๊ธฐ๋ฐ 3D ๊ทธ๋ํ (BGE-M3, Threshold: {similarity_threshold})',
            font=dict(size=16, family=plotly_font),
        ),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        scene=dict(
            xaxis=dict(title='TSNE Dimension 1', **axis_style),
            yaxis=dict(title='TSNE Dimension 2', **axis_style),
            zaxis=dict(title='TSNE Dimension 3', **axis_style),
            aspectratio=dict(x=1, y=1, z=0.8),
        ),
    )

    # Build and display the figure (edges first so nodes are drawn on top).
    fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
    print("*"*20)
    print(" ์ธํฐ๋ํฐ๋ธ 3D ๊ทธ๋ํ๋ฅผ ํ์ํฉ๋๋ค. ")
    print(" - ๋ง์ฐ์ค ํ : ์ค ์ธ/์์")
    print(" - ๋ง์ฐ์ค ๋๋๊ทธ: ํ์ ")
    print(" - ๋ ธ๋ ์์ ๋ง์ฐ์ค ์ฌ๋ฆฌ๊ธฐ: ์ดํ ์ด๋ฆ ๋ฐ ์ฐ๊ฒฐ ์ ํ์ธ")
    print("*"*20)
    fig.show()  # opens in a web browser or the IDE output pane
    print("๊ทธ๋ํ ํ์ ์๋ฃ.")