llm-thinking / utils.py
shftan's picture
Clean up code
af4a860
raw
history blame
2.89 kB
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as clrs
import requests
import json
import pandas as pd
import torch
# Function to get tokens given text
def get_tokens(tokenizer, text, device=None):
    """Tokenize ``text`` and return the token strings together with their ids.

    Args:
        tokenizer: Object exposing ``encode`` and ``convert_ids_to_tokens``
            (presumably a Hugging Face tokenizer -- confirm against callers).
        text: Text to tokenize. ``encode`` is called with
            ``add_special_tokens=False``.
        device: Device the encoded ids are moved to. ``None`` (the default)
            selects ``"cuda"`` when it is available and ``"cpu"`` otherwise,
            so the helper no longer crashes on CPU-only hosts.

    Returns:
        A ``(tokens, token_ids)`` tuple: ``token_ids`` is the result of
        ``encode`` moved to ``device``; ``tokens`` is the result of
        ``convert_ids_to_tokens`` applied to its first row.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    token_ids = tokenizer.encode(
        text, return_tensors="pt", add_special_tokens=False
    ).to(device)
    tokens = tokenizer.convert_ids_to_tokens(token_ids[0])
    return tokens, token_ids
# Function to apply chat template to prompt
def decorate_prompt(tokenizer, prompt, device=None):
    """Wrap ``prompt`` in the tokenizer's chat template and encode the result.

    The chat consists of one user turn holding ``prompt`` followed by an
    empty assistant turn.

    Args:
        tokenizer: Object exposing ``apply_chat_template`` and ``encode``
            (presumably a Hugging Face tokenizer -- confirm against callers).
        prompt: User message text.
        device: Device the encoded ids are moved to. ``None`` (the default)
            selects ``"cuda"`` when it is available and ``"cpu"`` otherwise,
            so the helper no longer crashes on CPU-only hosts.

    Returns:
        The result of ``tokenizer.encode`` on the templated text, moved to
        ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    chat = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": ""},
    ]
    text = tokenizer.apply_chat_template(chat, tokenize=False)
    # encode() is told not to add special tokens; presumably the chat
    # template already emits them -- confirm for the tokenizer in use.
    token_ids = tokenizer.encode(
        text, return_tensors="pt", add_special_tokens=False
    ).to(device)
    return token_ids
# Function to get response to prompt
def get_response(model_pipe, prompt):
    """Run ``prompt`` through ``model_pipe`` and return the generated text.

    ``model_pipe`` is called with the prompt; the ``'generated_text'`` entry
    of the first element of its result is returned.
    """
    outputs = model_pipe(prompt)
    first_output = outputs[0]
    return first_output['generated_text']
# Function to highlight tokens based on given values
def plot_tokens_with_highlights(tokens, values, concept, cmap_name='Oranges', vmin=None, vmax=None):
    """Build an HTML snippet in which each token is shaded by its value.

    Args:
        tokens: Sequence of token strings.
        values: Tensor with one value per token (``.detach()`` is called on
            it, so it is expected to be a torch tensor).
        concept: Concept name shown in the heading.
        cmap_name: Name of the Matplotlib colormap used for the shading.
        vmin: Lower bound of the color scale; defaults to the minimum value.
        vmax: Upper bound of the color scale; defaults to the maximum value.

    Returns:
        An HTML string: an ``<h3>`` heading followed by one ``<span>`` per
        token whose background color encodes the value and whose ``title``
        shows it with four decimals.

    Raises:
        ValueError: If ``tokens`` and ``values`` differ in length.
    """
    import html  # local import: only needed here, to escape text placed in HTML

    if len(tokens) != len(values):
        raise ValueError("The number of tokens and values must be the same.")

    # Escape the concept so characters such as '<' or '&' cannot break the markup.
    heading = (
        "<h3>How much information about the concept "
        f"'{html.escape(str(concept))}' is carried in each token:</h3>"
    )

    # Move to CPU and widen to float64 first: .numpy() rejects CUDA tensors
    # and bfloat16 tensors.
    scores = values.detach().cpu().double().numpy()
    if len(scores) == 0:
        # Nothing to shade; min()/max() would also fail on empty input.
        return heading

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9.
    cmap = plt.get_cmap(cmap_name)
    norm = clrs.Normalize(
        vmin=vmin if vmin is not None else scores.min(),
        vmax=vmax if vmax is not None else scores.max(),
    )

    spans = []
    for token, value in zip(tokens, scores):
        red, green, blue = cmap(norm(value))[:3]
        hex_color = '#%02x%02x%02x' % (int(red * 255), int(green * 255), int(blue * 255))
        # Tokens such as "<s>" would otherwise be parsed as HTML tags.
        spans.append(
            f'<span style="background-color: {hex_color};" title="{value:.4f}">'
            f'{html.escape(str(token))}</span> '
        )
    return heading + "".join(spans)
# Function to get concepts dictionary
def get_concepts_dictionary(dictionary_url, timeout=30):
    """Download a JSON-lines concept dictionary and map concept ids to names.

    Each non-empty line of the response is parsed as a JSON object; its
    ``"concept_id"`` and ``"concept"`` fields are read, and lines missing
    either one are skipped.

    Args:
        dictionary_url: URL of the JSON-lines file.
        timeout: Timeout in seconds passed to ``requests.get`` so a stalled
            server cannot block the caller indefinitely.

    Returns:
        Dict mapping each concept id to its capitalized concept name.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (via ``raise_for_status``).
    """
    data_dict = {}
    # The context manager releases the streamed connection even when parsing
    # fails part-way through the body.
    with requests.get(dictionary_url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        for line in response.iter_lines(decode_unicode=True):
            if not line:
                continue
            obj = json.loads(line)
            concept_id = obj.get("concept_id")
            concept = obj.get("concept")
            # Test the id for "missing" explicitly: a plain truthiness check
            # would silently drop the legitimate id 0.
            if concept_id in (None, "") or not concept:
                continue
            data_dict[concept_id] = concept.capitalize()
    return data_dict
# Function to get matching concepts
def select_concepts(all_concepts, desired_concept):
    """Find every concept whose name contains ``desired_concept``.

    Matching is a case-insensitive substring test against the concept names.

    Args:
        all_concepts: Dict mapping concept id to concept name.
        desired_concept: Substring to look for.

    Returns:
        A ``(concept_ids, concept_df)`` tuple: a tensor of the matching ids
        (in dict order) and a DataFrame with the columns ``"Concept ID"`` and
        ``"Concept Name"``. With no match the tensor is an empty int64 tensor
        and the DataFrame is empty but still has both columns.
    """
    needle = desired_concept.lower()
    matches = [
        (concept_id, name)
        for concept_id, name in all_concepts.items()
        if needle in name.lower()
    ]
    concept_ids = [concept_id for concept_id, _ in matches]

    # Naming the columns explicitly keeps them present when nothing matched.
    concept_df = pd.DataFrame(matches, columns=["Concept ID", "Concept Name"])

    # torch.tensor([]) would default to float32; keep the empty result
    # integer-typed, as it is for ids when there are matches.
    if concept_ids:
        id_tensor = torch.tensor(concept_ids)
    else:
        id_tensor = torch.empty(0, dtype=torch.long)
    return id_tensor, concept_df