import json
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt # matplotlib은 폰트 설정 로직에 필요
import matplotlib.font_manager as fm
import numpy as np
import platform
import os
import networkx as nx # 그래프 구조 생성
import plotly.graph_objects as go # 3D 시각화
from sklearn.metrics.pairwise import cosine_similarity # 유사도 계산
# --- 한글 폰트 설정 함수 ---
def set_korean_font():
"""
현재 운영체제에 맞는 한글 폰트를 matplotlib 및 Plotly용으로 설정 시도하고,
Plotly에서 사용할 폰트 이름을 반환합니다.
"""
system_name = platform.system()
plotly_font_name = None # Plotly에서 사용할 폰트 이름
# Matplotlib 폰트 설정
if system_name == "Windows":
font_name = "Malgun Gothic"
plotly_font_name = "Malgun Gothic"
elif system_name == "Darwin": # MacOS
font_name = "AppleGothic"
plotly_font_name = "AppleGothic"
elif system_name == "Linux":
# Linux에서 선호하는 한글 폰트 경로 또는 이름 설정
font_path = "/usr/share/fonts/truetype/nanum/NanumGothic.ttf"
plotly_font_name_linux = "NanumGothic" # Plotly는 폰트 '이름'을 주로 사용
if os.path.exists(font_path):
font_name = fm.FontProperties(fname=font_path).get_name()
plotly_font_name = plotly_font_name_linux
print(f"Using font: {font_name} from {font_path}")
else:
# 시스템에서 'Nanum' 포함 폰트 찾기 시도
try:
available_fonts = [f.name for f in fm.fontManager.ttflist]
nanum_fonts = [name for name in available_fonts if 'Nanum' in name]
if nanum_fonts:
font_name = nanum_fonts[0]
# Plotly에서 사용할 이름도 비슷하게 설정 (정확한 이름은 시스템마다 다를 수 있음)
plotly_font_name = font_name if 'Nanum' in font_name else plotly_font_name_linux
print(f"Found and using system font: {font_name}")
else:
# 다른 OS 폰트 시도
if "Malgun Gothic" in available_fonts:
font_name = "Malgun Gothic"
plotly_font_name = "Malgun Gothic"
elif "AppleGothic" in available_fonts:
font_name = "AppleGothic"
plotly_font_name = "AppleGothic"
else:
font_name = None
if font_name: print(f"Trying fallback font: {font_name}")
except Exception as e:
print(f"Error finding Linux font: {e}")
font_name = None
if not font_name:
print("Warning: Linux 한글 폰트를 자동으로 찾지 못했습니다. Matplotlib 기본 폰트를 사용합니다.")
font_name = None
plotly_font_name = None # Plotly도 기본값 사용
else: # 기타 OS
font_name = None
plotly_font_name = None
# Matplotlib 폰트 설정 적용
if font_name:
try:
plt.rc('font', family=font_name)
plt.rc('axes', unicode_minus=False)
print(f"Matplotlib font set to: {font_name}")
except Exception as e:
print(f"Error setting Matplotlib font '{font_name}': {e}. Using default.")
plt.rcdefaults()
plt.rc('axes', unicode_minus=False)
# Plotly 폰트 이름도 기본값으로 되돌릴 수 있음 (선택적)
# plotly_font_name = None
else:
print("Matplotlib Korean font not set. Using default font.")
plt.rcdefaults()
plt.rc('axes', unicode_minus=False)
if not plotly_font_name:
print("Plotly font name not explicitly found, will use Plotly default (sans-serif).")
plotly_font_name = 'sans-serif' # Plotly 기본값 지정
print(f"Plotly will try to use font: {plotly_font_name}")
return plotly_font_name # Plotly에서 사용할 폰트 이름 반환
# --- 데이터 로드 함수 ---
def load_titles_from_json(filepath):
""" JSON 파일에서 'title'만 리스트로 로드합니다. """
try:
with open(filepath, 'r', encoding='utf-8') as f:
data = json.load(f)
# data가 리스트 형태라고 가정
if isinstance(data, list):
titles = [item.get('word', '') for item in data if item.get('word')]
# 빈 문자열 제거
titles = [title for title in titles if title]
return titles
else:
print(f"오류: 파일 '{filepath}'의 최상위 형식이 리스트가 아닙니다.")
return None
except FileNotFoundError:
print(f"오류: 파일 '{filepath}'를 찾을 수 없습니다.")
return None
except json.JSONDecodeError:
print(f"오류: 파일 '{filepath}'의 JSON 형식이 잘못되었습니다.")
return None
except Exception as e:
print(f"데이터 로딩 중 오류 발생: {e}")
return None
# --- 메인 실행 부분 ---
if __name__ == "__main__":
# 한글 폰트 설정 (matplotlib용, Plotly용 이름도 받아옴)
plotly_font = set_korean_font()
# --- 설정값 ---
data_file_path = 'child_mind_data.json' # 입력 데이터 파일 경로
embedding_model_name = 'BAAI/bge-m3' # 수정: BGE-M3 모델로 변경
similarity_threshold = 0.7 # 엣지를 생성할 코사인 유사도 임계값 (0.0 ~ 1.0)
tsne_perplexity = 30 # t-SNE perplexity (데이터 수보다 작아야 함)
tsne_max_iter = 1000 # t-SNE 반복 횟수
# ---
# 1. 데이터 로드 (어휘 제목 리스트)
print(f"데이터 로딩 시도: {data_file_path}")
word_list = load_titles_from_json(data_file_path)
if not word_list:
print("시각화할 어휘 데이터가 없습니다. 프로그램을 종료합니다.")
exit() # 데이터 없으면 종료
else:
print(f"총 {len(word_list)}개의 유효한 어휘를 로드했습니다.")
# 중복 제거 (선택적)
original_count = len(word_list)
word_list = sorted(list(set(word_list)))
if len(word_list) < original_count:
print(f"중복 제거 후 {len(word_list)}개의 고유한 어휘가 남았습니다.")
# 2. 임베딩 모델 로드
print(f"임베딩 모델 로딩 중: {embedding_model_name} ...")
try:
model = SentenceTransformer(embedding_model_name)
except Exception as e:
print(f"오류: 임베딩 모델 '{embedding_model_name}' 로딩에 실패했습니다. {e}")
print("인터넷 연결 및 모델 이름을 확인하세요.")
exit()
print("모델 로딩 완료.")
# 3. 임베딩 생성
print("어휘 임베딩 생성 중...")
try:
# BGE 모델 특화 파라미터 추가
embeddings = model.encode(word_list, show_progress_bar=True, normalize_embeddings=True)
except Exception as e:
print(f"오류: 임베딩 생성 중 문제가 발생했습니다. {e}")
exit()
print(f"임베딩 생성 완료. 각 어휘는 {embeddings.shape[1]}차원 벡터로 변환되었습니다.")
# 4. 3D 좌표 생성 - t-SNE 사용
print("3차원 좌표 생성 중 (t-SNE)...")
# perplexity 값 조정 (데이터 수보다 작아야 함)
effective_perplexity = min(tsne_perplexity, len(word_list) - 1)
if effective_perplexity <= 0:
print(f"Warning: 데이터 수가 너무 적어 ({len(word_list)}개) perplexity를 5로 강제 설정합니다.")
effective_perplexity = 5 # 매우 작은 데이터셋 대비
try:
tsne = TSNE(n_components=3, random_state=42, perplexity=effective_perplexity, max_iter=tsne_max_iter, init='pca', learning_rate='auto')
embeddings_3d = tsne.fit_transform(embeddings)
except Exception as e:
print(f"오류: t-SNE 차원 축소 중 문제가 발생했습니다. {e}")
exit()
print("3차원 좌표 생성 완료.")
# 5. 유사도 계산 및 엣지 정의
print("어휘 간 유사도 계산 및 엣지 정의 중...")
try:
similarity_matrix = cosine_similarity(embeddings)
except Exception as e:
print(f"오류: 코사인 유사도 계산 중 문제가 발생했습니다. {e}")
exit()
edges = []
edge_weights = [] # 엣지 두께 등에 활용할 가중치
for i in range(len(word_list)):
for j in range(i + 1, len(word_list)): # 중복 및 자기 자신 연결 방지
similarity = similarity_matrix[i, j]
if similarity > similarity_threshold:
edges.append((word_list[i], word_list[j]))
edge_weights.append(similarity) # 유사도 값을 가중치로 사용
print(f"유사도 임계값 ({similarity_threshold}) 초과 엣지 {len(edges)}개 정의 완료.")
if not edges:
print("Warning: 정의된 엣지가 없습니다. 유사도 임계값이 너무 높거나 데이터 간 유사성이 낮을 수 있습니다.")
# 엣지가 없어도 노드만 표시하도록 계속 진행
# 6. NetworkX 그래프 생성
print("NetworkX 그래프 객체 생성 중...")
G = nx.Graph()
for i, word in enumerate(word_list):
# 노드 속성으로 3D 좌표 저장
G.add_node(word, pos=(embeddings_3d[i, 0], embeddings_3d[i, 1], embeddings_3d[i, 2]))
# 엣지와 가중치 추가
for edge, weight in zip(edges, edge_weights):
G.add_edge(edge[0], edge[1], weight=weight)
print("NetworkX 그래프 생성 완료.")
# --- Plotly를 사용한 3D 시각화 ---
print("Plotly 3D 그래프 생성 중...")
# 엣지 좌표 추출
edge_x = []
edge_y = []
edge_z = []
if edges: # 엣지가 있을 경우에만 처리
for edge in G.edges():
x0, y0, z0 = G.nodes[edge[0]]['pos']
x1, y1, z1 = G.nodes[edge[1]]['pos']
edge_x.extend([x0, x1, None]) # None을 넣어 선을 분리
edge_y.extend([y0, y1, None])
edge_z.extend([z0, z1, None])
# 엣지용 Scatter3d 트레이스 생성
edge_trace = go.Scatter3d(
x=edge_x, y=edge_y, z=edge_z,
mode='lines',
line=dict(width=1, color='#888'), # 엣지 색상 및 두께
hoverinfo='none' # 엣지에는 호버 정보 없음
)
else:
edge_trace = go.Scatter3d(x=[], y=[], z=[], mode='lines') # 엣지 없으면 빈 트레이스
# 노드 위치와 텍스트 추출
node_x = [G.nodes[node]['pos'][0] for node in G.nodes()]
node_y = [G.nodes[node]['pos'][1] for node in G.nodes()]
node_z = [G.nodes[node]['pos'][2] for node in G.nodes()]
node_text = list(G.nodes()) # 노드 이름 (어휘)
node_adjacencies = [] # 연결된 엣지 수 (마커 크기 등에 활용 가능)
node_hover_text = [] # 노드 호버 텍스트
for node, adjacencies in enumerate(G.adjacency()):
num_connections = len(adjacencies[1])
node_adjacencies.append(num_connections)
node_hover_text.append(f'{node_text[node]}
Connections: {num_connections}')
# 노드용 Scatter3d 트레이스 생성
node_trace = go.Scatter3d(
x=node_x, y=node_y, z=node_z,
mode='markers+text',
text=node_text,
hovertext=node_hover_text,
hoverinfo='text',
textposition='top center',
textfont=dict(
size=10,
color='black',
family=plotly_font
),
marker=dict(
size=6,
color=node_z,
colorscale='Viridis',
opacity=0.9,
colorbar=dict(thickness=15, title='Node Depth (Z-axis)', xanchor='left', title_side='right')
# titleside → title_side
)
)
# 레이아웃 설정
layout = go.Layout(
title=dict(
text=f'어휘 의미 유사성 기반 3D 그래프 (BGE-M3, Threshold: {similarity_threshold})',
font=dict(size=16, family=plotly_font)
),
showlegend=False,
hovermode='closest', # 가장 가까운 데이터 포인트 정보 표시
margin=dict(b=20, l=5, r=5, t=40), # 여백
scene=dict( # 3D 씬 설정
xaxis=dict(title='TSNE Dimension 1', showticklabels=False, backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
yaxis=dict(title='TSNE Dimension 2', showticklabels=False, backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
zaxis=dict(title='TSNE Dimension 3', showticklabels=False, backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
aspectratio=dict(x=1, y=1, z=0.8) # 축 비율 조정
),
# 주석 추가 (옵션)
# annotations=[
# dict(
# showarrow=False,
# text=f"Data: {data_file_path}
Model: {embedding_model_name}",
# xref="paper", yref="paper",
# x=0.005, y=0.005
# )
# ]
)
# Figure 생성 및 표시
fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
print("*"*20)
print(" 인터랙티브 3D 그래프를 표시합니다. ")
print(" - 마우스 휠: 줌 인/아웃")
print(" - 마우스 드래그: 회전")
print(" - 노드 위에 마우스 올리기: 어휘 이름 및 연결 수 확인")
print("*"*20)
# HTML 파일로 저장 (선택적)
# fig.write_html("3d_graph_visualization.html")
# print("그래프를 '3d_graph_visualization.html' 파일로 저장했습니다.")
fig.show() # 웹 브라우저 또는 IDE 출력 창에 표시
print("그래프 표시 완료.")