Spaces:
Running
Running
visualization
Browse files- app.py +44 -21
- lm_steer/models/model_base.py +1 -1
app.py
CHANGED
|
@@ -63,10 +63,15 @@ def word_embedding_space_analysis(
|
|
| 63 |
return pd.DataFrame(
|
| 64 |
data,
|
| 65 |
columns=["One Direction", "Another Direction"],
|
| 66 |
-
index=[f"Dim
|
| 67 |
)
|
| 68 |
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
def main():
|
| 71 |
# set up the page
|
| 72 |
random.seed(0)
|
|
@@ -103,17 +108,17 @@ def main():
|
|
| 103 |
# set up the model
|
| 104 |
st.divider()
|
| 105 |
st.divider()
|
| 106 |
-
st.subheader("Select
|
| 107 |
'''
|
| 108 |
Due to resource limits, we are only able to provide a few models for
|
| 109 |
steering. You can also refer to the Github repository:
|
| 110 |
-
https://github.com/Glaciohound/LM-Steer to host larger models.
|
| 111 |
Some generated texts may contain toxic or offensive content. Please be
|
| 112 |
cautious when using the generated texts.
|
| 113 |
Note that for these smaller models, the generation quality may not be as
|
| 114 |
good as the larger models (GPT-4, Llama, etc.).
|
| 115 |
'''
|
| 116 |
-
col1, col2 = st.columns(
|
| 117 |
model_name = col1.selectbox(
|
| 118 |
"Select a model to steer",
|
| 119 |
[
|
|
@@ -143,23 +148,23 @@ def main():
|
|
| 143 |
total_param = sum(p.numel() for _, p in model.named_parameters()) / \
|
| 144 |
1024 ** 2
|
| 145 |
ratio = num_param / total_param
|
| 146 |
-
|
| 147 |
-
|
|
|
|
| 148 |
|
| 149 |
# steering
|
| 150 |
steer_range = 3.
|
| 151 |
steer_interval = 0.2
|
| 152 |
-
st.subheader("Enter a sentence and steer the model")
|
| 153 |
st.session_state.prompt = st.text_input(
|
| 154 |
"Enter a prompt",
|
| 155 |
st.session_state.get("prompt", "My life")
|
| 156 |
)
|
| 157 |
col1, col2, col3 = st.columns([2, 2, 1], gap="medium")
|
| 158 |
sentiment = col1.slider(
|
| 159 |
-
"Sentiment (
|
| 160 |
-steer_range, steer_range, 0.0, steer_interval)
|
| 161 |
detoxification = col2.slider(
|
| 162 |
-
"Detoxification Strength (
|
| 163 |
-steer_range, steer_range, 0.0,
|
| 164 |
steer_interval)
|
| 165 |
max_length = col3.number_input("Max length", 20, 200, 20, 20)
|
|
@@ -191,7 +196,7 @@ def main():
|
|
| 191 |
# Analysing the sentence
|
| 192 |
st.divider()
|
| 193 |
st.divider()
|
| 194 |
-
st.subheader("
|
| 195 |
'''
|
| 196 |
LM-Steer also serves as a probe for analyzing the text. It can be used to
|
| 197 |
analyze the sentiment and detoxification of the text. Now, we proceed and
|
|
@@ -200,23 +205,25 @@ def main():
|
|
| 200 |
entangled, as a negative sentiment may also detoxify the text.
|
| 201 |
'''
|
| 202 |
if st.session_state.get("analyzed_text", "") != "" and \
|
| 203 |
-
st.button("Analyze the
|
| 204 |
col1, col2 = st.columns(2)
|
| 205 |
-
for name, col, dim, color in zip(
|
| 206 |
["Sentiment", "Detoxification"],
|
| 207 |
[col1, col2],
|
| 208 |
[2, 0],
|
| 209 |
["#ff7f0e", "#1f77b4"],
|
|
|
|
| 210 |
):
|
| 211 |
with st.spinner(f"Analyzing {name}..."):
|
| 212 |
col.subheader(name)
|
| 213 |
# classification
|
| 214 |
col.markdown(
|
| 215 |
"##### Sentence Classification Distribution")
|
|
|
|
| 216 |
_, dist_list, _ = model.steer_analysis(
|
| 217 |
st.session_state.analyzed_text,
|
| 218 |
dim, -steer_range, steer_range,
|
| 219 |
-
bins=
|
| 220 |
)
|
| 221 |
dist_list = np.array(dist_list)
|
| 222 |
col.bar_chart(
|
|
@@ -241,9 +248,7 @@ def main():
|
|
| 241 |
tokens = [f"{i:3d}: {tokenizer.decode([t])}"
|
| 242 |
for i, t in enumerate(tokens)]
|
| 243 |
col.markdown("##### Token's Evidence Score in the Dimension")
|
| 244 |
-
col.write(
|
| 245 |
-
"which aligns with sliding bar directions."
|
| 246 |
-
)
|
| 247 |
col.bar_chart(
|
| 248 |
pd.DataFrame(
|
| 249 |
{
|
|
@@ -256,23 +261,41 @@ def main():
|
|
| 256 |
|
| 257 |
st.divider()
|
| 258 |
st.divider()
|
| 259 |
-
st.subheader("
|
| 260 |
'''
|
| 261 |
LM-Steer provides a lens on how word embeddings correlate with LM word
|
| 262 |
embeddings: what word dimensions contribute to or contrast to a specific
|
| 263 |
style. This analysis can be used to understand the word embedding space
|
| 264 |
and how it steers the model's generation.
|
|
|
|
| 265 |
Note that due to the bidirectional nature of the embedding spaces, in each
|
| 266 |
-
dimension, sometimes only one side of the word embeddings
|
| 267 |
-
|
|
|
|
|
|
|
| 268 |
'''
|
| 269 |
for dimension in ["Sentiment", "Detoxification"]:
|
| 270 |
-
f'##### {dimension}
|
| 271 |
dim = 2 if dimension == "Sentiment" else 0
|
| 272 |
analysis_result = word_embedding_space_analysis(
|
| 273 |
model_name, dim)
|
| 274 |
with st.expander("Show the analysis results"):
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
|
| 277 |
|
| 278 |
if __name__ == "__main__":
|
|
|
|
| 63 |
return pd.DataFrame(
|
| 64 |
data,
|
| 65 |
columns=["One Direction", "Another Direction"],
|
| 66 |
+
index=[f"Dim#{_i}" for _i in range(n_dim)],
|
| 67 |
)
|
| 68 |
|
| 69 |
|
| 70 |
+
# rgb tuple to hex color
|
| 71 |
+
def rgb_to_hex(rgb):
|
| 72 |
+
return '#%02x%02x%02x' % rgb
|
| 73 |
+
|
| 74 |
+
|
| 75 |
def main():
|
| 76 |
# set up the page
|
| 77 |
random.seed(0)
|
|
|
|
| 108 |
# set up the model
|
| 109 |
st.divider()
|
| 110 |
st.divider()
|
| 111 |
+
st.subheader("Select A Model and Steer It")
|
| 112 |
'''
|
| 113 |
Due to resource limits, we are only able to provide a few models for
|
| 114 |
steering. You can also refer to the Github repository:
|
| 115 |
+
https://github.com/Glaciohound/LM-Steer to host larger models locally.
|
| 116 |
Some generated texts may contain toxic or offensive content. Please be
|
| 117 |
cautious when using the generated texts.
|
| 118 |
Note that for these smaller models, the generation quality may not be as
|
| 119 |
good as the larger models (GPT-4, Llama, etc.).
|
| 120 |
'''
|
| 121 |
+
col1, col2, col3, col4 = st.columns([3, 1, 1, 1])
|
| 122 |
model_name = col1.selectbox(
|
| 123 |
"Select a model to steer",
|
| 124 |
[
|
|
|
|
| 148 |
total_param = sum(p.numel() for _, p in model.named_parameters()) / \
|
| 149 |
1024 ** 2
|
| 150 |
ratio = num_param / total_param
|
| 151 |
+
col2.metric("Parameters Steered", f"{num_param:.1f}M")
|
| 152 |
+
col3.metric("LM Total Size", f"{total_param:.1f}M")
|
| 153 |
+
col4.metric("Steered Ratio", f"{ratio:.2%}")
|
| 154 |
|
| 155 |
# steering
|
| 156 |
steer_range = 3.
|
| 157 |
steer_interval = 0.2
|
|
|
|
| 158 |
st.session_state.prompt = st.text_input(
|
| 159 |
"Enter a prompt",
|
| 160 |
st.session_state.get("prompt", "My life")
|
| 161 |
)
|
| 162 |
col1, col2, col3 = st.columns([2, 2, 1], gap="medium")
|
| 163 |
sentiment = col1.slider(
|
| 164 |
+
"Sentiment (Negative ↔︎ Positive)",
|
| 165 |
-steer_range, steer_range, 0.0, steer_interval)
|
| 166 |
detoxification = col2.slider(
|
| 167 |
+
"Detoxification Strength (Toxic ↔︎ Clean)",
|
| 168 |
-steer_range, steer_range, 0.0,
|
| 169 |
steer_interval)
|
| 170 |
max_length = col3.number_input("Max length", 20, 200, 20, 20)
|
|
|
|
| 196 |
# Analysing the sentence
|
| 197 |
st.divider()
|
| 198 |
st.divider()
|
| 199 |
+
st.subheader("LM-Steer Converts LMs into Text Analyzers")
|
| 200 |
'''
|
| 201 |
LM-Steer also serves as a probe for analyzing the text. It can be used to
|
| 202 |
analyze the sentiment and detoxification of the text. Now, we proceed and
|
|
|
|
| 205 |
entangled, as a negative sentiment may also detoxify the text.
|
| 206 |
'''
|
| 207 |
if st.session_state.get("analyzed_text", "") != "" and \
|
| 208 |
+
st.button("Analyze the text above", type="primary"):
|
| 209 |
col1, col2 = st.columns(2)
|
| 210 |
+
for name, col, dim, color, axis_annotation in zip(
|
| 211 |
["Sentiment", "Detoxification"],
|
| 212 |
[col1, col2],
|
| 213 |
[2, 0],
|
| 214 |
["#ff7f0e", "#1f77b4"],
|
| 215 |
+
["Negative ↔︎ Positive", "Toxic ↔︎ Clean"]
|
| 216 |
):
|
| 217 |
with st.spinner(f"Analyzing {name}..."):
|
| 218 |
col.subheader(name)
|
| 219 |
# classification
|
| 220 |
col.markdown(
|
| 221 |
"##### Sentence Classification Distribution")
|
| 222 |
+
col.write(axis_annotation)
|
| 223 |
_, dist_list, _ = model.steer_analysis(
|
| 224 |
st.session_state.analyzed_text,
|
| 225 |
dim, -steer_range, steer_range,
|
| 226 |
+
bins=4*int(steer_range)+1,
|
| 227 |
)
|
| 228 |
dist_list = np.array(dist_list)
|
| 229 |
col.bar_chart(
|
|
|
|
| 248 |
tokens = [f"{i:3d}: {tokenizer.decode([t])}"
|
| 249 |
for i, t in enumerate(tokens)]
|
| 250 |
col.markdown("##### Token's Evidence Score in the Dimension")
|
| 251 |
+
col.write(axis_annotation)
|
|
|
|
|
|
|
| 252 |
col.bar_chart(
|
| 253 |
pd.DataFrame(
|
| 254 |
{
|
|
|
|
| 261 |
|
| 262 |
st.divider()
|
| 263 |
st.divider()
|
| 264 |
+
st.subheader("LM-Steer Unveils Word Embeddings Space")
|
| 265 |
'''
|
| 266 |
LM-Steer provides a lens on how word embeddings correlate with LM word
|
| 267 |
embeddings: what word dimensions contribute to or contrast to a specific
|
| 268 |
style. This analysis can be used to understand the word embedding space
|
| 269 |
and how it steers the model's generation.
|
| 270 |
+
|
| 271 |
Note that due to the bidirectional nature of the embedding spaces, in each
|
| 272 |
+
dimension, sometimes only one side of the word embeddings contributes
|
| 273 |
+
(has an impact on the style), while the other side, (resulting in negative
|
| 274 |
+
logits) has a negligible impact on the style. The table below shows both
|
| 275 |
+
sides of the word embeddings in each dimension.
|
| 276 |
'''
|
| 277 |
for dimension in ["Sentiment", "Detoxification"]:
|
| 278 |
+
f'##### {dimension} Word Dimensions'
|
| 279 |
dim = 2 if dimension == "Sentiment" else 0
|
| 280 |
analysis_result = word_embedding_space_analysis(
|
| 281 |
model_name, dim)
|
| 282 |
with st.expander("Show the analysis results"):
|
| 283 |
+
color_scale = 7
|
| 284 |
+
color_init = 230
|
| 285 |
+
st.table(analysis_result.style.apply(
|
| 286 |
+
lambda x: [
|
| 287 |
+
"background: " + rgb_to_hex(
|
| 288 |
+
(255,
|
| 289 |
+
color_init-(9-i)*color_scale,
|
| 290 |
+
color_init-(9-i)*color_scale)
|
| 291 |
+
if dimension == "Sentiment" else
|
| 292 |
+
(color_init-(9-i)*color_scale,
|
| 293 |
+
color_init-(9-i)*color_scale,
|
| 294 |
+
255)
|
| 295 |
+
)
|
| 296 |
+
for i in range(len(x))
|
| 297 |
+
]
|
| 298 |
+
))
|
| 299 |
|
| 300 |
|
| 301 |
if __name__ == "__main__":
|
lm_steer/models/model_base.py
CHANGED
|
@@ -109,7 +109,7 @@ class LMSteerBase(nn.Module):
|
|
| 109 |
)
|
| 110 |
loss_token = loss_token.reshape(bins + 1, length - 1)
|
| 111 |
loss = loss_token.mean(-1)[:-1]
|
| 112 |
-
dist = ((- loss + loss.mean()) *
|
| 113 |
dist_list = list(zip(
|
| 114 |
[
|
| 115 |
min_value + (max_value - min_value) / (bins - 1) * bin_i
|
|
|
|
| 109 |
)
|
| 110 |
loss_token = loss_token.reshape(bins + 1, length - 1)
|
| 111 |
loss = loss_token.mean(-1)[:-1]
|
| 112 |
+
dist = ((- loss + loss.mean()) * 10).softmax(0)
|
| 113 |
dist_list = list(zip(
|
| 114 |
[
|
| 115 |
min_value + (max_value - min_value) / (bins - 1) * bin_i
|