|
|
import torch.nn as nn |
|
|
import torch |
|
|
from transformers import AutoTokenizer |
|
|
import networkx as nx |
|
|
import plotly.graph_objects as go |
|
|
import random |
|
|
|
|
|
def find_similar_embeddings(target_embedding, n=10):
    """
    Find the n most similar vocabulary embeddings to ``target_embedding``
    using cosine similarity against the model's full embedding table.

    Args:
        target_embedding: 1-D embedding vector (tensor or array-like) to
            compare against every row of the embedding matrix.
        n: Number of similar embeddings to return (default 10).

    Returns:
        List of (word, similarity_score) tuples sorted by descending
        similarity. The target token itself is typically the top hit,
        so callers may want to filter it out.
    """
    if not isinstance(target_embedding, torch.Tensor):
        target_embedding = torch.tensor(target_embedding)

    # Full (vocab_size, embedding_dim) table from the module-level model.
    all_embeddings = model.embedding.weight

    # Pure lookup — no gradients needed, so skip building an autograd graph.
    with torch.no_grad():
        similarities = torch.nn.functional.cosine_similarity(
            target_embedding.unsqueeze(0),  # (1, dim) broadcast vs (vocab, dim)
            all_embeddings,
            dim=1,
        )
        top_n_similarities, top_n_indices = torch.topk(similarities, n)

    # Decode each vocabulary index back to its token string.
    results = []
    for idx, score in zip(top_n_indices, top_n_similarities):
        word = tokenizer.decode(idx)
        results.append((word, score.item()))

    return results
|
|
|
|
|
def prompt_to_embeddings(prompt: str):
    """
    Tokenize ``prompt`` and look up the embedding vector for each token.

    Args:
        prompt: Text to tokenize with the module-level tokenizer.

    Returns:
        Tuple of (token_id_list, embeddings, token_str):
            token_id_list: list of token ids (special tokens included)
            embeddings: tensor of shape (1, seq_len, embedding_dim)
            token_str: decoded surface string for each token id
    """
    tokens = tokenizer(prompt, return_tensors="pt")
    input_ids = tokens['input_ids']

    # Embedding lookup only — no gradients needed.
    with torch.no_grad():
        embeddings = model(input_ids)

    # Reuse the ids already computed above instead of tokenizing the
    # prompt a second time with tokenizer.encode().
    token_id_list = input_ids[0].tolist()
    token_str = [tokenizer.decode(t_id, skip_special_tokens=True) for t_id in token_id_list]

    return token_id_list, embeddings, token_str
|
|
|
|
|
class EmbeddingModel(nn.Module):
    """Minimal wrapper around a single nn.Embedding lookup table."""

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        # One row per vocabulary entry, one column per embedding dimension.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )

    def forward(self, input_ids):
        """Map integer token ids to their embedding vectors."""
        return self.embedding(input_ids)
|
|
|
|
|
|
|
|
# --- Model / tokenizer configuration -------------------------------------
vocab_size = 151936          # vocabulary size of the Qwen-family tokenizer below
dimensions = 1536            # embedding dimension of the saved table
embeddings_filename = r"python\code\files\embeddings_qwen.pth"
tokenizer_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

model = EmbeddingModel(vocab_size, dimensions)

# map_location="cpu" so weights saved from a GPU still load on CPU-only hosts
# (the rest of this script runs everything on CPU).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
saved_embeddings = torch.load(embeddings_filename, map_location="cpu")

# The file is expected to be a state-dict-like mapping with a 'weight' entry.
if 'weight' not in saved_embeddings:
    raise KeyError("The saved embeddings file does not contain 'weight' key.")

embeddings_tensor = saved_embeddings['weight']

# Fail fast if the saved table does not match the configured shape.
if embeddings_tensor.size() != (vocab_size, dimensions):
    raise ValueError(f"The dimensions of the loaded embeddings do not match the model's expected dimensions ({vocab_size}, {dimensions}).")

# Install the loaded table as the embedding layer's weights.
model.embedding.weight.data = embeddings_tensor

model.eval()
|
|
|
|
|
token_id_list, prompt_embeddings, prompt_token_str = prompt_to_embeddings("""We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely""")

# For each prompt token (position 0 skipped), collect its most similar
# vocabulary tokens, dropping trivial self-matches (case/whitespace-insensitive).
tokens_and_neighbors = {}
for tok_str, tok_emb in zip(prompt_token_str[1:], prompt_embeddings[0][1:]):
    matches = find_similar_embeddings(tok_emb, n=40)
    tokens_and_neighbors[tok_str] = [
        word for word, _ in matches
        if word.strip().lower() != tok_str.strip().lower()
    ]
|
|
|
|
|
# Embedding vector for every word that will appear as a graph node, keyed by
# surface string. Index [0][1] selects the token at position 1 of the
# re-tokenized word (the position after the leading special token).
all_token_embeddings = {}
for token, neighbors in tokens_and_neighbors.items():
    for word in [token] + neighbors:
        _, word_emb, _ = prompt_to_embeddings(word)
        all_token_embeddings[word] = word_emb[0][1]
|
|
|
|
|
|
|
|
# Undirected graph: one edge per (prompt token, similar token) pair.
G = nx.Graph()
G.add_edges_from(
    (token, neighbor)
    for token, neighbors in tokens_and_neighbors.items()
    for neighbor in neighbors
)
|
|
|
|
|
|
|
|
k = 2  # NOTE(review): unused in the visible code — looks like a leftover layout parameter (e.g. spring_layout's k); confirm against the rest of the file before removing

# Force-directed layout; yields a dict mapping node -> (x, y).
# NOTE(review): forceatlas2_layout requires a recent networkx release — verify the pinned version.
pos = nx.forceatlas2_layout(G, max_iter=36)
|
|
|
|
|
|
|
|
# Canvas scale applied to the layout's coordinates.
viz_width = 1500
viz_height = 500

# Plotly draws each edge as a separate segment when the coordinate lists
# are interleaved with None separators.
edge_x, edge_y = [], []
for u, v in G.edges():
    ux, uy = pos[u]
    vx, vy = pos[v]
    edge_x.extend([ux * viz_width, vx * viz_width, None])
    edge_y.extend([uy * viz_height, vy * viz_height, None])
|
|
|
|
|
|
|
|
# Scaled node coordinates and per-node degree counts.
node_x = [pos[n][0] * viz_width for n in G.nodes()]
node_y = [pos[n][1] * viz_height for n in G.nodes()]
node_degrees = dict(G.degree())

colors = []  # kept: defined but not used in this section — may be referenced later in the file

# Connected components, used below to color nodes by cluster.
components = list(nx.connected_components(G))

node_to_color = {}  # kept: defined but not used in this section — may be referenced later in the file
node_opacities = []
node_labels = []
hover_labels = []
text_opacities = []

# Per-node styling: prompt tokens (keys of tokens_and_neighbors) are drawn
# more opaque and with visible text labels; neighbor-only nodes are faded
# and their labels hidden (alpha 0).
node_component_indices = []
for node in G.nodes():
    node_component_indices.append(
        next(i for i, comp in enumerate(components) if node in comp)
    )
    is_prompt_token = node in tokens_and_neighbors
    node_opacities.append(0.9 if is_prompt_token else 0.6)
    text_opacities.append(1.0 if is_prompt_token else 0.0)
    node_labels.append(node)
    hover_labels.append(node)

# Marker size grows with degree; +5 keeps isolated-looking nodes visible.
node_sizes = [degree + 5 for degree in node_degrees.values()]
|
|
|
|
|
|
|
|
# Per-node hover payload: the node label plus a pipe-separated neighbor list.
node_customdata = [
    [hover_labels[i], ' | '.join(G.neighbors(node))]
    for i, node in enumerate(G.nodes())
]

node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    mode='markers+text',
    text=node_labels,
    textposition="top center",
    textfont=dict(
        # Per-label alpha: prompt-token labels visible, neighbor labels hidden.
        color=[f'rgba(0,0,0,{opacity})' for opacity in text_opacities]
    ),
    marker=dict(
        size=node_sizes,
        color=node_component_indices,
        colorscale='plasma',
        opacity=node_opacities,
        line_width=0.5,
    ),
    customdata=node_customdata,
    hovertemplate="<b>%{customdata[0]}</b><br>Similar tokens: %{customdata[1]}<extra></extra>",
    hoverlabel=dict(namelength=0),
)
|
|
|
|
|
|
|
|
# Thin grey line segments for the edges; no hover interaction on them.
edge_trace = go.Scatter(
    mode='lines',
    x=edge_x,
    y=edge_y,
    hoverinfo='none',
    line=dict(color='grey', width=0.5),
)
|
|
|
|
|
|
|
|
# Axes are fully hidden; the plot is a free-floating graph drawing.
hidden_axis = dict(
    showgrid=False,
    zeroline=False,
    showticklabels=False,
)

fig = go.Figure(
    data=[edge_trace, node_trace],
    layout=go.Layout(
        width=1200,
        height=400,
        paper_bgcolor='white',
        plot_bgcolor='white',
        showlegend=False,
        margin=dict(l=0, r=0, t=0, b=0),
        xaxis=hidden_axis,
        # Lock the y axis to the x axis so distances are not distorted.
        yaxis=dict(scaleanchor="x", scaleratio=1, **hidden_axis),
    ),
)
fig.show()
|
|
|
|
|
# Export an embeddable fragment: no <html> wrapper and no bundled plotly.js
# (the host page is expected to provide the library).
html_export_config = {
    'displayModeBar': False,
    'responsive': True,
    'scrollZoom': False,
}
fig.write_html(
    r"src\fragments\token_visualization.html",
    include_plotlyjs=False,
    full_html=False,
    config=html_export_config,
)

...