Spaces:
Running
Running
feat: Initial commit
Browse files- requirements.txt +2 -1
- src/streamlit_app.py +111 -29
requirements.txt
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
altair
|
| 2 |
pandas
|
| 3 |
-
streamlit
|
|
|
|
|
|
| 1 |
altair
|
| 2 |
pandas
|
| 3 |
+
streamlit
|
| 4 |
+
streamlit_vertical_slider
|
src/streamlit_app.py
CHANGED
|
@@ -1,40 +1,122 @@
|
|
| 1 |
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
|
| 14 |
-
""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
radius = indices
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
.encode(
|
| 36 |
-
x=alt.X("
|
| 37 |
-
y=alt.Y("
|
| 38 |
-
color=alt.
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import altair as alt
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
import streamlit as st
|
| 4 |
+
import streamlit_vertical_slider as svs
|
| 5 |
+
import torch
|
| 6 |
+
from streamlit_vertical_slider import vertical_slider
|
| 7 |
|
| 8 |
+
st.title("Number Token Loss - Demo")
|
|
|
|
| 9 |
|
| 10 |
+
st.markdown("""
|
| 11 |
+
Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
|
| 12 |
+
The sliders are vertical and compact. The app normalizes the slider values
|
| 13 |
+
to form a valid probability distribution, visualizes it, and computes the corresponding
|
| 14 |
+
Cross Entropy, NTL-MSE, and NTL-WAS losses.
|
| 15 |
+
""")
|
| 16 |
|
| 17 |
+
# Vertical sliders for predicted probabilities of tokens 0-9 and "Text"
|
| 18 |
+
st.markdown("#### Predicted Token Probabilities")
|
| 19 |
+
cols = st.columns(11)
|
| 20 |
+
prob_values = []
|
| 21 |
+
for i, col in enumerate(cols):
|
| 22 |
+
label = f"Token {i}" if i < 10 else "Text"
|
| 23 |
+
with col:
|
| 24 |
+
val = svs.vertical_slider(
|
| 25 |
+
label=label,
|
| 26 |
+
min_value=0.0,
|
| 27 |
+
max_value=1.0,
|
| 28 |
+
step=0.1,
|
| 29 |
+
height=50,
|
| 30 |
+
key=f"slider_{i}",
|
| 31 |
+
slider_color="green",
|
| 32 |
+
track_color="lightgray",
|
| 33 |
+
thumb_color="black",
|
| 34 |
+
)
|
| 35 |
+
prob_values.append(val)
|
| 36 |
|
| 37 |
+
# Normalize the probabilities to sum to 1
|
| 38 |
+
total = sum(prob_values)
|
| 39 |
+
probs = (
|
| 40 |
+
torch.ones(11) / 11.0
|
| 41 |
+
if total == 0
|
| 42 |
+
else torch.tensor([v / total for v in prob_values])
|
| 43 |
+
)
|
| 44 |
|
| 45 |
+
# Token labels
|
| 46 |
+
options = [str(i) for i in range(10)] + ["Text"]
|
|
|
|
| 47 |
|
| 48 |
+
# Ground truth token selection
|
| 49 |
+
gt_choice = st.selectbox("Ground Truth Token", options=options, index=0)
|
| 50 |
+
if gt_choice == "Text":
|
| 51 |
+
gt_index = 10
|
| 52 |
+
gt_numeric = None
|
| 53 |
+
else:
|
| 54 |
+
gt_index = int(gt_choice)
|
| 55 |
+
gt_numeric = gt_index
|
| 56 |
|
| 57 |
+
# Visualize the input distribution with highlighted ground truth bar
|
| 58 |
+
st.markdown("#### Input Probability Distribution")
|
| 59 |
+
df_dist = pd.DataFrame({"token": options, "probability": probs.numpy()})
|
| 60 |
+
chart = (
|
| 61 |
+
alt.Chart(df_dist)
|
| 62 |
+
.mark_bar()
|
|
|
|
|
|
|
|
|
|
| 63 |
.encode(
|
| 64 |
+
x=alt.X("token:N", title="Token"),
|
| 65 |
+
y=alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1])),
|
| 66 |
+
color=alt.condition(
|
| 67 |
+
alt.datum.token == gt_choice,
|
| 68 |
+
alt.value("green"), # Highlight ground truth token
|
| 69 |
+
alt.value("steelblue"), # Other tokens
|
| 70 |
+
),
|
| 71 |
+
)
|
| 72 |
+
.properties(height=300)
|
| 73 |
+
)
|
| 74 |
+
st.altair_chart(chart, use_container_width=True)
|
| 75 |
+
|
| 76 |
+
# Compute Cross Entropy loss: -log(predicted probability of the ground truth)
|
| 77 |
+
ce_loss = -torch.log(torch.clamp(probs[gt_index], min=1e-9))
|
| 78 |
+
|
| 79 |
+
# Compute NTL-MSE loss
|
| 80 |
+
if gt_numeric is None:
|
| 81 |
+
ntl_mse_loss = torch.tensor(0.0)
|
| 82 |
+
else:
|
| 83 |
+
numeric_probs = probs[:10]
|
| 84 |
+
values = torch.arange(0, 10, dtype=torch.float32)
|
| 85 |
+
pred_value = torch.sum(numeric_probs * values)
|
| 86 |
+
ntl_mse_loss = (pred_value - float(gt_numeric)) ** 2
|
| 87 |
+
|
| 88 |
+
# Compute NTL-WAS loss
|
| 89 |
+
if gt_numeric is None:
|
| 90 |
+
ntl_was_loss = torch.tensor(0.0)
|
| 91 |
+
else:
|
| 92 |
+
numeric_probs = probs[:10]
|
| 93 |
+
values = torch.arange(0, 10, dtype=torch.float32)
|
| 94 |
+
abs_diff = torch.abs(values - float(gt_numeric))
|
| 95 |
+
ntl_was_loss = torch.sum(numeric_probs * abs_diff)
|
| 96 |
+
|
| 97 |
+
# Convert losses to Python floats and round to 3 decimals
|
| 98 |
+
ce_val = round(ce_loss.item(), 3)
|
| 99 |
+
mse_val = round(ntl_mse_loss.item(), 3)
|
| 100 |
+
was_val = round(ntl_was_loss.item(), 3)
|
| 101 |
+
|
| 102 |
+
# Display numeric values of the losses
|
| 103 |
+
st.subheader("Loss Values")
|
| 104 |
+
st.write(f"**Cross Entropy:** {ce_val:.3f}")
|
| 105 |
+
st.write(f"**NTL-MSE:** {mse_val:.3f}")
|
| 106 |
+
st.write(f"**NTL-WAS:** {was_val:.3f}")
|
| 107 |
+
|
| 108 |
+
# Bar chart comparison of the three losses
|
| 109 |
+
st.subheader("Loss Comparison Chart")
|
| 110 |
+
loss_df = pd.DataFrame(
|
| 111 |
+
{
|
| 112 |
+
"Loss": ["Cross Entropy", "NTL-MSE", "NTL-WAS"],
|
| 113 |
+
"Value": [ce_val, mse_val, was_val],
|
| 114 |
+
}
|
| 115 |
+
).set_index("Loss")
|
| 116 |
+
st.bar_chart(loss_df)
|
| 117 |
+
|
| 118 |
+
# References / resources section with links
|
| 119 |
+
st.markdown("### Resources")
|
| 120 |
+
st.markdown(
|
| 121 |
+
"- **Paper:** [Regress, Don't Guess – A Regression-like Loss on Number Tokens for Language Models](https://arxiv.org/abs/2411.02083) \n- **Code:** [tum-ai/number-token-loss (GitHub)](https://github.com/tum-ai/number-token-loss)"
|
| 122 |
+
)
|