Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,7 +18,7 @@ CONFIG.set_default_api_key(api_key)
|
|
| 18 |
access_token = os.environ['HUGGING_FACE_HUB_TOKEN']
|
| 19 |
|
| 20 |
# Load the Language Model
|
| 21 |
-
llama = LanguageModel("meta-llama/Meta-Llama-3.1-8B"
|
| 22 |
|
| 23 |
#placeholder for reset
|
| 24 |
prompts_with_probs = pd.DataFrame(
|
|
@@ -55,7 +55,9 @@ def run_lens(model,PROMPT):
|
|
| 55 |
logits_lens_token_result_by_layer.append(logits_lens_next_token)
|
| 56 |
tokens_out = llama.lm_head.output.argmax(dim=-1).save()
|
| 57 |
expected_token = tokens_out[0][-1].save()
|
|
|
|
| 58 |
logits_lens_all_probs = np.concatenate([probs[:, expected_token].cpu().detach().to(torch.float32).numpy() for probs in logits_lens_probs_by_layer])
|
|
|
|
| 59 |
#get the rank of the expected token from each layer's distribution
|
| 60 |
for layer_probs in logits_lens_probs_by_layer:
|
| 61 |
# Sort the probabilities in descending order and find the rank of the expected token
|
|
@@ -113,7 +115,7 @@ def plot_prob(prompts_with_probs):
|
|
| 113 |
# Add labels and title
|
| 114 |
plt.xlabel('Layer Number')
|
| 115 |
plt.ylabel('Probability of Expected Token')
|
| 116 |
-
plt.title('Prob of expected token across layers\n(annotated with decoded output at each layer)')
|
| 117 |
plt.grid(True)
|
| 118 |
plt.ylim(0.0, 1.0)
|
| 119 |
plt.legend(title='Prompts', bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=1)
|
|
@@ -177,6 +179,8 @@ def plot_prob_mean(prompts_with_probs):
|
|
| 177 |
plt.title('Mean Probability of Expected Token')
|
| 178 |
plt.xticks(rotation=45, ha='right')
|
| 179 |
plt.grid(axis='y')
|
|
|
|
|
|
|
| 180 |
|
| 181 |
# Annotate the mean and variance on the bars
|
| 182 |
for bar, mean, var in zip(bars, summary_stats['mean_prob'], summary_stats['variance']):
|
|
@@ -277,18 +281,25 @@ def submit_prompts(prompts_data):
|
|
| 277 |
|
| 278 |
def clear_all(prompts):
|
| 279 |
prompts=[['']]
|
|
|
|
|
|
|
| 280 |
prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1, value= prompts, type="array", interactive=True)
|
| 281 |
-
return prompts_data,plot_prob(prompts_with_probs),plot_rank(prompts_with_ranks),plot_prob_mean(prompts_with_probs),plot_rank_mean(prompts_with_ranks)
|
| 282 |
|
| 283 |
|
| 284 |
def gradio_interface():
|
| 285 |
with gr.Blocks(theme="gradio/monochrome") as demo:
|
| 286 |
prompts=[['']]
|
| 287 |
-
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
| 289 |
prompt_file.upload(process_file, inputs=[prompts_data,prompt_file], outputs=[prompts_data])
|
| 290 |
-
|
| 291 |
# Define the outputs
|
|
|
|
|
|
|
|
|
|
| 292 |
with gr.Row():
|
| 293 |
prob_visualization = gr.Image(value=plot_prob(prompts_with_probs), type="pil",label=" ")
|
| 294 |
rank_visualization = gr.Image(value=plot_rank(prompts_with_ranks), type="pil",label=" ")
|
|
@@ -296,14 +307,11 @@ def gradio_interface():
|
|
| 296 |
prob_mean_visualization = gr.Image(value=plot_prob_mean(prompts_with_probs), type="pil",label=" ")
|
| 297 |
rank_mean_visualization = gr.Image(value=plot_rank_mean(prompts_with_ranks), type="pil",label=" ")
|
| 298 |
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
submit_btn = gr.Button("Submit")
|
| 303 |
-
submit_btn.click(submit_prompts, inputs=[prompts_data], outputs=[prob_visualization,rank_visualization,prob_mean_visualization,rank_mean_visualization])#
|
| 304 |
|
| 305 |
|
| 306 |
demo.launch()
|
| 307 |
|
| 308 |
-
|
| 309 |
gradio_interface()
|
|
|
|
| 18 |
access_token = os.environ['HUGGING_FACE_HUB_TOKEN']
|
| 19 |
|
| 20 |
# Load the Language Model
|
| 21 |
+
llama = LanguageModel("meta-llama/Meta-Llama-3.1-8B")
|
| 22 |
|
| 23 |
#placeholder for reset
|
| 24 |
prompts_with_probs = pd.DataFrame(
|
|
|
|
| 55 |
logits_lens_token_result_by_layer.append(logits_lens_next_token)
|
| 56 |
tokens_out = llama.lm_head.output.argmax(dim=-1).save()
|
| 57 |
expected_token = tokens_out[0][-1].save()
|
| 58 |
+
# logits_lens_all_probs = np.concatenate([probs[:, expected_token].cpu().detach().numpy() for probs in logits_lens_probs_by_layer])
|
| 59 |
logits_lens_all_probs = np.concatenate([probs[:, expected_token].cpu().detach().to(torch.float32).numpy() for probs in logits_lens_probs_by_layer])
|
| 60 |
+
|
| 61 |
#get the rank of the expected token from each layer's distribution
|
| 62 |
for layer_probs in logits_lens_probs_by_layer:
|
| 63 |
# Sort the probabilities in descending order and find the rank of the expected token
|
|
|
|
| 115 |
# Add labels and title
|
| 116 |
plt.xlabel('Layer Number')
|
| 117 |
plt.ylabel('Probability of Expected Token')
|
| 118 |
+
plt.title('Prob of expected token across layers\n(annotated with actual decoded output at each layer)')
|
| 119 |
plt.grid(True)
|
| 120 |
plt.ylim(0.0, 1.0)
|
| 121 |
plt.legend(title='Prompts', bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=1)
|
|
|
|
| 179 |
plt.title('Mean Probability of Expected Token')
|
| 180 |
plt.xticks(rotation=45, ha='right')
|
| 181 |
plt.grid(axis='y')
|
| 182 |
+
plt.ylim(0, 1)
|
| 183 |
+
|
| 184 |
|
| 185 |
# Annotate the mean and variance on the bars
|
| 186 |
for bar, mean, var in zip(bars, summary_stats['mean_prob'], summary_stats['variance']):
|
|
|
|
| 281 |
|
| 282 |
def clear_all(prompts):
|
| 283 |
prompts=[['']]
|
| 284 |
+
# prompt_file=gr.File(type="filepath", label="Upload a File with Prompts")
|
| 285 |
+
prompt_file = None
|
| 286 |
prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1, value= prompts, type="array", interactive=True)
|
| 287 |
+
return prompts_data,prompt_file,plot_prob(prompts_with_probs),plot_rank(prompts_with_ranks),plot_prob_mean(prompts_with_probs),plot_rank_mean(prompts_with_ranks)
|
| 288 |
|
| 289 |
|
| 290 |
def gradio_interface():
|
| 291 |
with gr.Blocks(theme="gradio/monochrome") as demo:
|
| 292 |
prompts=[['']]
|
| 293 |
+
with gr.Row():
|
| 294 |
+
with gr.Column(scale=3):
|
| 295 |
+
prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1, value= prompts, type="array", interactive=True)
|
| 296 |
+
with gr.Column(scale=1):
|
| 297 |
+
prompt_file=gr.File(type="filepath", label="Upload a File with Prompts")
|
| 298 |
prompt_file.upload(process_file, inputs=[prompts_data,prompt_file], outputs=[prompts_data])
|
|
|
|
| 299 |
# Define the outputs
|
| 300 |
+
with gr.Row():
|
| 301 |
+
clear_btn = gr.Button("Clear")
|
| 302 |
+
submit_btn = gr.Button("Submit")
|
| 303 |
with gr.Row():
|
| 304 |
prob_visualization = gr.Image(value=plot_prob(prompts_with_probs), type="pil",label=" ")
|
| 305 |
rank_visualization = gr.Image(value=plot_rank(prompts_with_ranks), type="pil",label=" ")
|
|
|
|
| 307 |
prob_mean_visualization = gr.Image(value=plot_prob_mean(prompts_with_probs), type="pil",label=" ")
|
| 308 |
rank_mean_visualization = gr.Image(value=plot_rank_mean(prompts_with_ranks), type="pil",label=" ")
|
| 309 |
|
| 310 |
+
clear_btn.click(clear_all, inputs=[prompts_data], outputs=[prompts_data,prompt_file,prob_visualization,rank_visualization,prob_mean_visualization,rank_mean_visualization])
|
| 311 |
+
submit_btn.click(submit_prompts, inputs=[prompts_data], outputs=[prob_visualization,rank_visualization,prob_mean_visualization,rank_mean_visualization])#
|
| 312 |
+
prompt_file.clear(clear_all, inputs=[prompts_data], outputs=[prompts_data,prompt_file,prob_visualization,rank_visualization,prob_mean_visualization,rank_mean_visualization])
|
|
|
|
|
|
|
| 313 |
|
| 314 |
|
| 315 |
demo.launch()
|
| 316 |
|
|
|
|
| 317 |
gradio_interface()
|