| | import torch.nn as nn |
| | import torch |
| | from transformers import AutoTokenizer |
| | import networkx as nx |
| | import plotly.graph_objects as go |
| | import random |
| |
|
def find_similar_embeddings(target_embedding, n=10):
    """
    Find the n most similar vocabulary embeddings to a target embedding.

    Similarity is cosine similarity between ``target_embedding`` and every
    row of the global ``model.embedding.weight`` matrix; matches are decoded
    back to text via the global ``tokenizer``.

    Args:
        target_embedding: 1-D embedding vector (tensor or array-like) to
            compare against.  # assumes shape (embedding_dim,) — TODO confirm
        n: Number of similar embeddings to return (default 10).

    Returns:
        List of (word, similarity_score) tuples sorted by descending
        similarity.
    """
    if not isinstance(target_embedding, torch.Tensor):
        target_embedding = torch.tensor(target_embedding)

    # Pure lookup — no gradients needed, so avoid building an autograd
    # graph over the full (vocab_size, dim) weight matrix.
    with torch.no_grad():
        all_embeddings = model.embedding.weight

        # (1, dim) broadcast against (vocab_size, dim) -> (vocab_size,) scores.
        similarities = torch.nn.functional.cosine_similarity(
            target_embedding.unsqueeze(0),
            all_embeddings,
        )

        top_n_similarities, top_n_indices = torch.topk(similarities, n)

    results = []
    for idx, score in zip(top_n_indices, top_n_similarities):
        word = tokenizer.decode(idx)
        results.append((word, score.item()))

    return results
| |
|
def prompt_to_embeddings(prompt: str):
    """
    Tokenize a prompt and look up its embedding vectors.

    Args:
        prompt: Raw text to tokenize with the global ``tokenizer``.

    Returns:
        Tuple of:
            token_id_list: list[int] of token ids (special tokens included).
            embeddings: tensor from the global embedding ``model``,
                shape (1, seq_len, embedding_dim).
            token_str: list[str] with the decoded text of each token
                (special tokens decode to empty strings).
    """
    tokens = tokenizer(prompt, return_tensors="pt")
    input_ids = tokens['input_ids']

    # Embedding lookup only — no gradients needed.
    with torch.no_grad():
        embeddings = model(input_ids)

    # Reuse the ids we already computed instead of tokenizing the prompt a
    # second time with tokenizer.encode(); the two calls produce the same
    # ids, and a single tokenization removes any risk of divergence.
    token_id_list = input_ids[0].tolist()
    token_str = [tokenizer.decode(t_id, skip_special_tokens=True) for t_id in token_id_list]

    return token_id_list, embeddings, token_str
| |
|
class EmbeddingModel(nn.Module):
    """Minimal module wrapping a single token-embedding lookup table."""

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        # One learnable row per vocabulary entry.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, input_ids):
        # Map each token id to its embedding row.
        return self.embedding(input_ids)
| | |
| |
|
# --- Model / tokenizer setup -------------------------------------------------
# Vocab size and embedding width of the Qwen-1.5B distill checkpoint whose
# embedding matrix is stored in the .pth file below.
vocab_size = 151936
dimensions = 1536
embeddings_filename = r"python\code\files\embeddings_qwen.pth"
tokenizer_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

# Bare embedding-only model; weights are loaded from disk below.
model = EmbeddingModel(vocab_size, dimensions)

# NOTE(review): torch.load unpickles arbitrary objects — only load files you
# trust (consider weights_only=True on newer torch versions).
saved_embeddings = torch.load(embeddings_filename)

# The file is expected to be an nn.Embedding state dict with a 'weight' entry.
if 'weight' not in saved_embeddings:
    raise KeyError("The saved embeddings file does not contain 'weight' key.")

embeddings_tensor = saved_embeddings['weight']

# Guard against loading a checkpoint from a different model size.
if embeddings_tensor.size() != (vocab_size, dimensions):
    raise ValueError(f"The dimensions of the loaded embeddings do not match the model's expected dimensions ({vocab_size}, {dimensions}).")

# Install the pretrained embedding matrix into our lookup module.
model.embedding.weight.data = embeddings_tensor

# Inference only from here on.
model.eval()
| |
|
# Embed the demo prompt (the "Attention Is All You Need" abstract opener).
token_id_list, prompt_embeddings, prompt_token_str = prompt_to_embeddings("""We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely""")

# For each prompt token, collect its nearest-neighbor tokens in embedding
# space, excluding the token itself (case/whitespace-insensitive match).
# The loop starts at index 1 — presumably to skip a leading special/BOS
# token; TODO confirm against this tokenizer's output.
tokens_and_neighbors = {}
for i in range(1, len(prompt_embeddings[0])):
    token_results = find_similar_embeddings(prompt_embeddings[0][i], n=40)
    similar_embs = []
    for word, score in token_results:
        if word.strip().lower() != prompt_token_str[i].strip().lower():
            similar_embs.append(word)
    tokens_and_neighbors[prompt_token_str[i]] = similar_embs
| |
|
# Re-embed every token and neighbor string individually.
# NOTE(review): this dict is never read later in the visible script —
# possibly left over from an earlier version, or used further down the file.
all_token_embeddings = {}

for token, neighbors in tokens_and_neighbors.items():
    # [0][1] takes the embedding at position 1 — presumably the first real
    # token after a BOS special token.  If a string tokenizes into multiple
    # tokens, only the first is kept; TODO confirm this is intended.
    token_id, token_emb, _ = prompt_to_embeddings(token)
    all_token_embeddings[token] = token_emb[0][1]

    for neighbor in neighbors:
        neighbor_id, neighbor_emb, _ = prompt_to_embeddings(neighbor)
        all_token_embeddings[neighbor] = neighbor_emb[0][1]
| |
|
| | |
# --- Graph construction ------------------------------------------------------
# Undirected graph: each prompt token is connected to its neighbor tokens.
G = nx.Graph()

for token, neighbors in tokens_and_neighbors.items():
    for neighbor in neighbors:
        G.add_edge(token, neighbor)

# NOTE(review): `k` appears unused — forceatlas2_layout below does not take
# it.  Possibly a leftover from an earlier spring_layout(k=...) call.
k = 2

# Force-directed layout; positions are in layout units (roughly [-1, 1]).
pos = nx.forceatlas2_layout(G, max_iter=36)
| |
|
| | |
# Scale factors applied to layout coordinates before plotting.
viz_width = 1500
viz_height = 500

# Build flat coordinate lists for the edge trace; a None entry after each
# edge tells plotly to break the line between segments.
edge_x, edge_y = [], []
for edge in G.edges():
    x0, y0 = pos[edge[0]]
    x1, y1 = pos[edge[1]]

    x0, x1 = x0 * viz_width, x1 * viz_width
    y0, y1 = y0 * viz_height, y1 * viz_height
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

# Node coordinates in the same scaled space.
node_x = [pos[node][0] * viz_width for node in G.nodes()]
node_y = [pos[node][1] * viz_height for node in G.nodes()]
node_degrees = dict(G.degree())

# NOTE(review): `colors` and `node_to_color` are never populated/used below —
# coloring is done via `node_component_indices` instead; likely dead code.
colors = []
components = list(nx.connected_components(G))

node_to_color = {}
node_opacities = []
node_labels = []
hover_labels = []
text_opacities = []

# Per-node styling: color index = connected component, prompt tokens get a
# visible text label, neighbor-only tokens get an invisible one (opacity 0)
# so labels don't clutter the plot but hover still works.
node_component_indices = []
for node in G.nodes():
    # Index of the connected component this node belongs to (used as color).
    for i, component in enumerate(components):
        if node in component:
            node_component_indices.append(i)
            break

    # Prompt tokens (dict keys) are emphasized; neighbors are dimmed.
    if node in tokens_and_neighbors:
        node_opacities.append(0.9)
        text_opacities.append(1.0)
        node_labels.append(node)
        hover_labels.append(node)
    else:
        node_opacities.append(0.6)
        text_opacities.append(0.0)
        node_labels.append(node)
        hover_labels.append(node)

# Marker size grows with degree; +5 keeps isolated-looking nodes visible.
node_sizes = [(degree + 5) * 1 for degree in node_degrees.values()]
| |
|
| | |
# Scatter trace for the nodes: markers colored by component, labels whose
# visibility is controlled per-node via rgba text color.
node_trace = go.Scatter(
    x=node_x, y=node_y,
    mode='markers+text',
    text=node_labels,
    textposition="top center",
    textfont=dict(
        # Per-node label opacity: 1.0 for prompt tokens, 0.0 (hidden) for neighbors.
        color=[f'rgba(0,0,0,{opacity})' for opacity in text_opacities]
    ),
    marker=dict(
        size=node_sizes,
        color=node_component_indices,
        colorscale='plasma',
        opacity=node_opacities,
        line_width=0.5
    ),
    # customdata rows: [hover label, pipe-joined neighbor list] per node.
    customdata=[[hover_labels[i], ' | '.join(G.neighbors(node))] for i, node in enumerate(G.nodes())],
    hovertemplate="<b>%{customdata[0]}</b><br>Similar tokens: %{customdata[1]}<extra></extra>",
    hoverlabel=dict(namelength=0)
)

# Single trace for all edges (segments separated by the None sentinels).
edge_trace = go.Scatter(
    x=edge_x, y=edge_y,
    line=dict(width=0.5, color='grey'),
    hoverinfo='none',
    mode='lines'
)
| |
|
| | |
# Assemble the figure: edges drawn first so nodes render on top; axes and
# chrome hidden for a clean embed.
fig = go.Figure(data=[edge_trace, node_trace],
                layout=go.Layout(
                    width=1200,
                    height=400,
                    paper_bgcolor='white',
                    plot_bgcolor='white',
                    showlegend=False,
                    margin=dict(l=0, r=0, t=0, b=0),
                    xaxis=dict(
                        showgrid=False,
                        zeroline=False,
                        showticklabels=False,
                    ),
                    yaxis=dict(
                        showgrid=False,
                        zeroline=False,
                        showticklabels=False,
                        # Lock the aspect ratio so the layout isn't distorted.
                        scaleanchor="x",
                        scaleratio=1
                    )
                ))
fig.show()

# Export as an embeddable fragment (no <html> wrapper, plotly.js expected to
# be loaded by the host page).
fig.write_html(r"src\fragments\token_visualization.html",
               include_plotlyjs=False,
               full_html=False,
               config={
                   'displayModeBar': False,
                   'responsive': True,
                   'scrollZoom': False,
               })

# No-op placeholder statement (Ellipsis); has no effect at runtime.
...