import html

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
import torch.nn as nn
from transformers import DistilBertTokenizerFast, DistilBertModel

from encoder_layer import Encoder_block
from nlp_functions import extract_keyphrases, get_attention_weights, plot_attention_heatmap, predict_ner_tags, render_ner_html, merge_subwords_and_bio
# ---------------------
# Load Model Components
# ---------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
d_model = 512
num_heads = 8
d_ff = 2048


@st.cache_resource
def _load_model_components(device, d_model, num_heads, d_ff):
    """Build the model components and load their weights from ``restored_model.pt``.

    Streamlit re-executes this whole script on every widget interaction.
    Caching with ``st.cache_resource`` makes the tokenizer/model download and
    the checkpoint read happen once per argument combination instead of on
    every rerun.

    Args:
        device: Torch device string the modules are moved to ("cuda" or "cpu").
        d_model: Hidden width of the custom encoder.
        num_heads: Number of attention heads in the custom encoder.
        d_ff: Feed-forward width of the custom encoder.

    Returns:
        Tuple ``(tokenizer, bert_emb, projection, encoder, classifier)``.
        ``bert_emb``, ``projection``, ``encoder`` and ``classifier`` are on
        ``device``. The last three are switched to eval mode.
    """
    tok = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    # Only DistilBERT's input-embedding layer is kept; the rest of the model is discarded.
    emb = (
        DistilBertModel.from_pretrained("distilbert-base-uncased")
        .get_input_embeddings()
        .to(device)
    )

    # Input width 768 must match the projection weights stored in the checkpoint.
    proj = nn.Linear(768, d_model)
    enc = Encoder_block(d_model=d_model, num_heads=num_heads, d_ff=d_ff)
    # 3 output classes -- presumably the O/B/I tag set; confirm against training code.
    clf = nn.Linear(d_model, 3)

    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    ckpt = torch.load("restored_model.pt", map_location=device)
    for module, key in (
        (proj, "projection_state_dict"),
        (enc, "encoder_state_dict"),
        (clf, "classifier_state_dict"),
    ):
        module.load_state_dict(ckpt[key])
        module.to(device)
        module.eval()

    return tok, emb, proj, enc, clf


# NOTE(review): projection/encoder/classifier are not referenced anywhere else in
# this file, and the nlp_functions helpers only receive text -- confirm whether
# loading them here is needed at all.
tokenizer, bert_emb, projection, encoder, classifier = _load_model_components(
    device, d_model, num_heads, d_ff
)
# ---------------------
# Streamlit Frontend
# ---------------------
# The sidebar radio picks which of the demo pages below gets rendered on this run.
FEATURE_CHOICES = ["Extract Keywords", "Attention Heatmap Viewer", "NER Demo"]
sidebar = st.sidebar
feature = sidebar.radio(label="Features", options=FEATURE_CHOICES)
sidebar.markdown("")
sidebar.image("encoder2.png", width=200)
if feature == "Extract Keywords":
    # Page: extract keyphrases from free text with the custom encoder.
    st.markdown(
        "<h2>Keyword Extraction using Custom Transformer Encoder</h2>",
        unsafe_allow_html=True,
    )
    st.markdown("Enter text below and get extracted keyphrases:")
    user_text = st.text_area("Input Text", height=100)

    # ---------------------
    # Keyword Extraction
    # ---------------------
    if st.button("Extract Keywords"):
        if not user_text.strip():
            st.warning("Please enter some text!")
        else:
            with st.spinner("Extracting keywords..."):
                keyphrases = extract_keyphrases(user_text)

            st.markdown("<h3>✨ Extracted Keyphrases</h3>", unsafe_allow_html=True)
            # len() instead of truthiness: the return type of extract_keyphrases
            # is not visible here and may be an array-like.
            if len(keyphrases) == 0:
                st.info("No meaningful keyphrases detected.")
            else:
                # Keyphrases are derived from user-supplied text and rendered with
                # unsafe_allow_html=True, so escape each one to prevent HTML/script
                # injection. str() keeps the original f-string formatting of kp.
                chips = "".join(
                    "<span style='display:inline-block; margin:4px; padding:4px 10px; "
                    "border-radius:12px; background-color:#eef2ff;'>"
                    f"{html.escape(str(kp))}</span>"
                    for kp in keyphrases
                )
                st.markdown(f"<div>{chips}</div>", unsafe_allow_html=True)
if feature == "Attention Heatmap Viewer":
    # Page: visualise one attention head of the custom encoder for a sentence.
    st.markdown("<h2>Attention Heatmap Viewer</h2>", unsafe_allow_html=True)

    # The last computed result is kept in session state. Changing the selected
    # head triggers a rerun, and this avoids re-running get_attention_weights.
    for state_key, default in (("attn_cache", None), ("tokens_cache", None), ("last_text", "")):
        if state_key not in st.session_state:
            st.session_state[state_key] = default

    col1, col2 = st.columns([4, 1])
    with col1:
        with st.form("input_form"):
            sentence = st.text_area(
                "Enter a sentence:", value=st.session_state.last_text, height=100
            )
            submitted = st.form_submit_button("See HeatMap")

    # Process input when submitted.
    if submitted:
        if not sentence.strip():
            st.warning("Please enter some text!")
        else:
            with st.spinner("Computing attention weights..."):
                st.session_state.last_text = sentence
                st.session_state.attn_cache, st.session_state.tokens_cache = get_attention_weights(sentence)
            st.success("Attention weights computed!")

    # Only show the visualisation once attention data exists.
    if st.session_state.attn_cache is not None:
        attn_weights = st.session_state.attn_cache
        tokens = st.session_state.tokens_cache
        # Dim 1 is read as the head count and [0, head] is sliced below; presumably
        # the layout is (batch, heads, query, key) -- confirm in get_attention_weights.
        # Named n_heads so it does not shadow the module-level num_heads config.
        n_heads = attn_weights.shape[1]
        with col2:
            head_display = st.selectbox(
                "Select an attention head to visualize:",
                options=list(range(1, n_heads + 1)),
                index=0,
                key="head_selector",  # stable widget identity across reruns
            )
        head_idx = head_display - 1  # the selectbox shows 1-based head numbers
        attn = attn_weights[0, head_idx].cpu().numpy()

        plot_attention_heatmap(attn, tokens, head_display)

        # Show the raw attention matrix. Tokens are suffixed with their position
        # so that repeated tokens give unique DataFrame row/column labels.
        st.markdown(f"**Raw Attention Scores - Head {head_display}:**")
        unique_tokens = [f"{tok}_{i}" for i, tok in enumerate(tokens)]
        df_attn = pd.DataFrame(attn, columns=unique_tokens, index=unique_tokens)
        st.dataframe(df_attn.style.format("{:.6f}").background_gradient(cmap="viridis"))
    else:
        # Placeholder until a sentence has been processed.
        with col2:
            st.selectbox(
                "Select an attention head to visualize:",
                options=[1],  # dummy
                disabled=True,
            )
        st.info("👈 Enter a sentence and click 'See HeatMap' to visualize attention!")
if feature == "NER Demo":
    # Page: word-level named entity recognition with highlighted output.
    st.markdown("<h2>NER Demo with Custom Transformer Encoder</h2>", unsafe_allow_html=True)
    # Plain markdown (not wrapped in block HTML) so that **bold** actually renders.
    st.markdown(
        "Enter a sentence below and see **word-level** named entity recognition with clean highlighting."
    )
    user_text = st.text_area(
        "Enter your sentence:",
        height=120,
        placeholder="Example: Apple is opening its first store in Delhi, India next month.",
    )

    if st.button("Predict Named Entities", type="primary", use_container_width=True):
        if not user_text.strip():
            st.warning("Please enter some text!")
        else:
            with st.spinner("Running your custom NER model..."):
                tokens, preds = predict_ner_tags(user_text)
                # Merge sub-word tokens back into whole words with one BIO tag each.
                words, tags = merge_subwords_and_bio(tokens, preds)
                ner_html = render_ner_html(words, tags)

            # Highlighted output. The wrapper lets long output wrap within the page width.
            st.markdown("<h3>Predicted Entities</h3>", unsafe_allow_html=True)
            # NOTE(review): ner_html is injected as raw HTML and contains words from
            # user input -- confirm render_ner_html escapes them.
            st.markdown(
                "<div style='line-height:2; white-space:normal; overflow-wrap:anywhere;'>"
                f"{ner_html}</div>",
                unsafe_allow_html=True,
            )

            # Per-word table; "B-"/"I-" prefixes are stripped to give the entity type.
            st.markdown("#### Token → NER Tag")
            df = pd.DataFrame({
                "Word": words,
                "BIO Tag": tags,
                "Entity": [tag[2:] if tag.startswith(("B-", "I-")) else "O" for tag in tags],
                "Full Label": [tag if tag != "O" else "Outside" for tag in tags],
            })
            st.dataframe(df, use_container_width=True, hide_index=True)