# --- Page header -----------------------------------------------------------
st.markdown(
    # NOTE(review): comma after the string restored — the scraped diff showed
    # the literal and unsafe_allow_html=True with no separator (SyntaxError).
    "<h1 style='text-align: center; color: pink;'>🌸 Transflower 🌸</h1>",
    unsafe_allow_html=True,  # needed so Streamlit renders the raw HTML
)

# --- Load model and tokenizer ----------------------------------------------
# t5-small: a lightweight seq2seq transformer. eval() disables dropout so the
# attention maps and generated summaries are deterministic between reruns.
model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.eval()

# --- Input area ------------------------------------------------------------
user_input = st.text_area("🌼 Enter text to summarize or visualize:", height=200)

if st.button("✨ Visualize Transformer Magic ✨"):
    if not user_input.strip():
        st.warning("Please enter some text to visualize.")
    else:
        # Encode input. T5 is a text-to-text model: the task is selected by
        # prepending the "summarize: " prefix to the prompt.
        inputs = tokenizer(
            "summarize: " + user_input, return_tensors="pt", truncation=True
        )

        # Run the encoder explicitly to capture self-attention weights, then
        # generate the summary — a single no_grad scope covers both, since
        # nothing here is ever backpropagated.
        with torch.no_grad():
            encoder_outputs = model.encoder(
                **inputs, output_attentions=True, return_dict=True
            )
            # Last encoder layer, batch item 0, averaged over heads:
            # a (seq_len, seq_len) token-to-token attention matrix.
            attention = (
                encoder_outputs.attentions[-1][0].mean(dim=0).detach().numpy()
            )
            summary_ids = model.generate(inputs["input_ids"], max_length=50)

        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        st.subheader("🌸 Summary:")
        st.success(summary)

        st.subheader("🔍 Encoder Attention Heatmap:")
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.heatmap(attention, cmap="YlGnBu", ax=ax)
        ax.set_title("Encoder Self-Attention Heatmap 💫")
        st.pyplot(fig)