Spaces:
Running
Running
File size: 20,565 Bytes
11c72a2 90e22b0 11c72a2 48b808a 11c72a2 90e22b0 11c72a2 2706eb3 11c72a2 2706eb3 11c72a2 2706eb3 11c72a2 010c288 11c72a2 8a9090a 11c72a2 2706eb3 11c72a2 2706eb3 11c72a2 2706eb3 90e22b0 774f4d3 11c72a2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 
477 478 479 480 481 482 483 484 485 486 487 | import streamlit as st
import plotly.graph_objects as go
import plotly.colors as pc
import sys
import os
import base64
import streamlit.components.v1 as components
import html
# Absolute path to the repo root (assuming `ui.py` is in /app)
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(REPO_ROOT)
ASSETS_DIR = os.path.join(REPO_ROOT, 'assets')
DATA_DIR = os.path.join(REPO_ROOT, 'data')
# Import functions from the backend
from backend.inference.process_beta import (
load_beta_matrix,
get_top_words_over_time,
load_time_labels
)
from backend.inference.word_selector import get_interesting_words, get_word_trend
from backend.inference.indexing_utils import load_index
from backend.inference.doc_retriever import (
load_length_stats,
get_yearly_counts_for_word,
deduplicate_docs,
get_all_documents_for_word_year,
highlight_words,
extract_snippet
)
from backend.llm_utils.summarizer import summarize_multiword_docs, ask_multiturn_followup
from backend.llm_utils.label_generator import get_topic_labels
from backend.llm.llm_router import get_llm, list_supported_models
from backend.llm_utils.token_utils import estimate_k_max_from_word_stats
def get_base64_image(image_path):
    """Return the contents of the file at ``image_path`` as base64 text.

    The file is read in binary mode, its bytes are base64-encoded, and the
    encoded bytes are decoded to ``str`` using the default codec.
    """
    with open(image_path, "rb") as image_handle:
        raw_bytes = image_handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode()
# --- Page Configuration ---
# Page title, icon, wide layout, and the "Get Help" / "Report a bug" menu
# links, which point at the project repository and its new-issue page.
page_menu_links = {
    'Get Help': 'https://github.com/AdhyaSuman/DTECT',
    'Report a bug': "https://github.com/AdhyaSuman/DTECT/issues/new",
}
st.set_page_config(
    page_title="DTECT",
    page_icon="π",
    layout="wide",
    menu_items=page_menu_links,
)
# Sidebar branding and repo link: the logo from the assets directory is
# embedded as a base64 data URI, wrapped in a link to the GitHub repository,
# and followed by a horizontal rule.
logo_file = os.path.join(ASSETS_DIR, 'Logo_light.png')
logo_b64 = get_base64_image(logo_file)
branding_template = """
<div style="text-align: center;">
<a href="https://github.com/AdhyaSuman/DTECT" target="_blank">
<img src="data:image/png;base64,{}" width="180" style="margin-bottom: 18px;">
</a>
<hr style="margin-bottom: 0;">
</div>
"""
st.sidebar.markdown(
    branding_template.format(logo_b64),
    unsafe_allow_html=True,
)

# 1. Sidebar: Model and Dataset Selection
st.sidebar.title("Configuration")
# Topic-model names offered in the sidebar "Model" selector.
AVAILABLE_MODELS = [
    "DTM",
    "DETM",
    "CFDTM",
]

# LLM provider name -> name of the environment variable holding its API key.
ENV_VAR_MAP = dict(
    OpenAI="OPENAI_API_KEY",
    Anthropic="ANTHROPIC_API_KEY",
    Gemini="GEMINI_API_KEY",
)
def list_datasets(data_dir):
    """Return the sorted names of the immediate subdirectories of ``data_dir``.

    Entries of ``data_dir`` that are not directories are skipped.
    """
    subdirectory_names = []
    for entry_name in os.listdir(data_dir):
        candidate = os.path.join(data_dir, entry_name)
        if os.path.isdir(candidate):
            subdirectory_names.append(entry_name)
    subdirectory_names.sort()
    return subdirectory_names
# Sidebar selectors for the dataset (one per subdirectory of DATA_DIR) and the
# topic-model architecture.
with st.sidebar.expander("Select Dataset & Topic Model", expanded=True):
    datasets = list_datasets(DATA_DIR)
    selected_dataset = st.selectbox("Dataset", datasets, help="Choose an available dataset.")
    selected_model = st.selectbox("Model", AVAILABLE_MODELS, help="Select topic model architecture.")

# With no dataset directories the selectbox has nothing to return, and the
# path resolution that follows would fail on a None dataset name. Stop with a
# readable message instead.
if not datasets or selected_dataset is None:
    st.error(f"No datasets found in '{DATA_DIR}'. Add a dataset directory and reload.")
    st.stop()

# Check if the dataset has changed and reset session state if it has.
if st.session_state.get("current_dataset") != selected_dataset:
    st.session_state.current_dataset = selected_dataset
    # List all session state keys that depend on the dataset
    keys_to_clear = [
        "selected_words",
        "interesting_words",
        "word_counts_multiselect",
        "collected_deduplicated_docs",
        "summary",
        "context_for_followup",
        "followup_history",
    ]
    for key in keys_to_clear:
        if key in st.session_state:
            del st.session_state[key]
    # Rerun the script to apply the clean state
    st.rerun()
# Resolve paths
# Dataset-level files sit under DATA_DIR/<dataset>; model-level files sit one
# directory deeper, under DATA_DIR/<dataset>/<model>.
dataset_path = os.path.join(DATA_DIR, selected_dataset)
model_path = os.path.join(dataset_path, selected_model)

# Files in the dataset's "processed" subdirectory.
vocab_path, time2id_path, length_stats_path, lemma_map_path = (
    os.path.join(dataset_path, relative_name)
    for relative_name in (
        "processed/vocab.txt",
        "processed/time2id.txt",
        "processed/length_stats.json",
        "processed/lemma_to_forms.json",
    )
)

# Files directly inside the dataset directory.
docs_path, index_path = (
    os.path.join(dataset_path, file_name)
    for file_name in ("docs.jsonl", "inverted_index.json")
)

# Files inside the selected model's directory.
beta_path, label_cache_path = (
    os.path.join(model_path, file_name)
    for file_name in ("beta.npy", "topic_label_cache.json")
)
# Sidebar controls for the LLM provider and model, plus API-key entry.
# The key is read from the provider's environment variable; if it is unset or
# empty, a form asks the user for one.
with st.sidebar.expander("LLM Settings", expanded=True):
    provider = st.selectbox("LLM Provider", options=list(ENV_VAR_MAP.keys()), help="Choose the LLM backend.")
    available_models = list_supported_models(provider)
    model = st.selectbox("LLM Model", options=available_models)
    env_var = ENV_VAR_MAP[provider]
    api_key = os.getenv(env_var)

    # The app counts as configured exactly when a non-empty key is available.
    st.session_state.llm_configured = bool(api_key)

    if not api_key:
        with st.form(key="api_key_form"):
            entered_key = st.text_input(f"Enter your {provider} API Key", type="password")
            submitted = st.form_submit_button("Submit and Confirm")
            if submitted:
                if entered_key:
                    # NOTE(review): os.environ is process-wide, so a key
                    # submitted here is visible to everything running in this
                    # process, not only to the submitting session. Confirm
                    # this is acceptable for shared deployments.
                    os.environ[env_var] = entered_key
                    api_key = entered_key
                    st.session_state.llm_configured = True
                    # Rerun so the script restarts with the key in the environment.
                    st.rerun()
                else:
                    st.warning("Please enter a key.")

# Nothing below can work without a key, so halt the script here.
if not st.session_state.llm_configured:
    st.warning("Please configure your LLM settings in the sidebar.")
    st.stop()

# Initialize LLM with the provided key
llm = get_llm(provider=provider, model=model, api_key=api_key)
# 3. Load Data
@st.cache_resource
def load_resources(beta_path, vocab_path, docs_path, index_path, time2id_path, length_stats_path, lemma_map_path):
    """Load the model and dataset artifacts and return them as one tuple.

    The result is cached through ``st.cache_resource``.

    Returns:
        A 7-tuple ``(beta, vocab, index, docs, lemma_to_forms, time_labels,
        length_stats)`` assembled from the backend loader functions.
    """
    # The vocabulary is loaded first because the index loader takes it as input.
    beta_matrix, vocabulary = load_beta_matrix(beta_path, vocab_path)
    inverted_index, documents, lemma_forms = load_index(
        docs_file_path=docs_path,
        vocab=vocabulary,
        index_path=index_path,
        lemma_map_path=lemma_map_path,
    )
    labels = load_time_labels(time2id_path)
    stats = load_length_stats(length_stats_path)
    resources = (beta_matrix, vocabulary, inverted_index, documents, lemma_forms, labels, stats)
    return resources
# --- Main Title and Paper-aligned Intro ---
st.markdown("""# π DTECT: Dynamic Topic Explorer & Context Tracker""")

# --- Load resources ---
# Any loading failure is reported in the page and stops the script.
resource_paths = (
    beta_path,
    vocab_path,
    docs_path,
    index_path,
    time2id_path,
    length_stats_path,
    lemma_map_path,
)
try:
    beta, vocab, index, docs, lemma_to_forms, time_labels, length_stats = load_resources(*resource_paths)
except FileNotFoundError as missing_file_error:
    st.error(f"Missing required file: {missing_file_error}")
    st.stop()
except Exception as load_error:
    st.error(f"Failed to load data: {str(load_error)}")
    st.stop()

# One integer position per time label.
timestamps = [*range(len(time_labels))]
# Size of the second axis of the beta matrix.
num_topics = beta.shape[1]

# Estimate max_k based on document length stats and selected LLM
average_doc_length = length_stats.get("avg_len")
suggested_max_k = estimate_k_max_from_word_stats(average_doc_length, model_name=model, provider=provider)
# ==============================================================================
# 1. TOPIC LABELING
# ==============================================================================
st.markdown("## 1οΈβ£ π·οΈ Topic Labeling")
st.info("Topics are automatically labeled using LLMs by analyzing their temporal word distributions.")
with st.spinner("β¨ Generating topic labels... LLM will be used only if labels are not cached."):
    topic_labels = get_topic_labels(beta, vocab, time_labels, llm, label_cache_path)

if not topic_labels:
    st.error("No topic labels are available for the selected model.")
    st.stop()

# Select by topic id and only *display* the label. Two topics can receive the
# same label; mapping the chosen label back to an id would then make all but
# one of them unreachable.
selected_topic = st.selectbox(
    "Select a Topic",
    list(topic_labels.keys()),
    format_func=lambda topic_id: topic_labels[topic_id],
    help="LLM-generated topic label",
)
selected_topic_label = topic_labels[selected_topic]
# ==============================================================================
# 2. INFORMATIVE WORD DETECTION & TREND VISUALIZATION
# ==============================================================================
st.markdown("---")
st.markdown("## 2οΈβ£ π‘ Informative Word Detection & π Trend Visualization")
st.info("Explore top/interesting words for each topic, and visualize their trends over time.")

top_n_words = st.slider("Number of Top Words per Topic", min_value=5, max_value=500, value=500)
top_words = get_top_words_over_time(
    beta=beta,
    vocab=vocab,
    topic_id=selected_topic,
    top_n=top_n_words
)

st.write(f"### Top {top_n_words} Words for Topic '{selected_topic_label}' (Ranked):")

# Render the ranked words as a scrollable block of up to four columns.
words_per_col = (top_n_words + 3) // 4  # ceiling of top_n_words / 4
columns = [top_words[i:i + words_per_col] for i in range(0, len(top_words), words_per_col)]
top_words_parts = [
    "<div style='max-height: 200px; overflow-y: auto; padding: 0 10px;'>",
    "<div style='display: flex; gap: 20px;'>",
]
word_rank = 1
for col in columns:
    top_words_parts.append("<div style='flex: 1;'>")
    for word in col:
        # Words are inserted into raw HTML, so escape them.
        top_words_parts.append(
            f"<div style='margin-bottom: 4px;'>{word_rank}. {html.escape(str(word))}</div>"
        )
        word_rank += 1
    top_words_parts.append("</div>")
top_words_parts.append("</div></div>")
scrollable_top_words = "".join(top_words_parts)
st.markdown(scrollable_top_words, unsafe_allow_html=True)
st.markdown("<div style='margin-top: 18px;'></div>", unsafe_allow_html=True)

if st.button("π‘ Suggest Informative Words", key="suggest_topic_words"):
    # `top_words` computed above already reflects the current topic and slider
    # value, so it is reused here.
    interesting_words = get_interesting_words(beta, vocab, topic_id=selected_topic, restrict_to=top_words)
    st.session_state.interesting_words = interesting_words
    # Pre-fill the multiselect below. This must happen before that widget is
    # created in this run, because it shares the "selected_words" key.
    st.session_state.selected_words = interesting_words[:15]
    styled_words = " ".join([
        f"<span style='background-color:#e0f7fa; color:#004d40; font-weight:500; padding:4px 8px; margin:4px; border-radius:8px; display:inline-block;'>{html.escape(str(w))}</span>"
        for w in interesting_words
    ])
    st.markdown(
        f"**Top Informative Words from Topic '{selected_topic_label}':**<br>{styled_words}",
        unsafe_allow_html=True
    )

st.markdown("#### π Plot Word Trends Over Time")
all_word_options = vocab
interesting_words = st.session_state.get("interesting_words", [])
if "selected_words" not in st.session_state:
    st.session_state.selected_words = interesting_words[:15]  # initial default

# The widget value lives in st.session_state["selected_words"], which is
# always initialised above. Passing `default=` as well makes Streamlit warn
# that the widget has both a default and a Session State value.
selected_words = st.multiselect(
    "Select words to visualize trends",
    options=all_word_options,
    key="selected_words"
)

if selected_words:
    fig = go.Figure()
    color_cycle = pc.qualitative.Plotly
    for i, word in enumerate(selected_words):
        trend = get_word_trend(beta, vocab, word, topic_id=selected_topic)
        color = color_cycle[i % len(color_cycle)]  # wrap around the palette
        fig.add_trace(go.Scatter(
            x=time_labels,
            y=trend,
            name=word,
            mode='lines+markers',  # Explicitly add markers to the lines
            line=dict(color=color),
            legendgroup=word,
            showlegend=True
        ))
    fig.update_layout(
        title="",
        xaxis_title="Year",
        yaxis_title="Importance",
        legend=dict(
            font=dict(
                size=16
            )
        )
    )
    # Centre the chart in the middle 60% of the page width.
    _, chart_col, _ = st.columns([0.2, 0.6, 0.2])
    with chart_col:
        st.plotly_chart(fig, use_container_width=True, theme=None)
# ==============================================================================
# 3. DOCUMENT RETRIEVAL & SUMMARIZATION
# ==============================================================================
st.markdown("---")
st.markdown("## 3οΈβ£ π Document Retrieval & π Summarization")
st.info("Retrieve and summarize documents matching selected words and years.")

# Defined up front so the chat section further down can always reference
# them, even when no word is selected.
selected_words_for_counts = []
selected_year = None

if selected_words:
    st.markdown("#### π Document Frequency Over Time")
    selected_words_for_counts = st.multiselect(
        "Select word(s) to show document frequencies over time",
        options=selected_words,
        default=selected_words[:3],
        key="word_counts_multiselect"
    )

    if selected_words_for_counts:
        color_cycle = pc.qualitative.Set2
        bar_fig = go.Figure()
        for i, word in enumerate(selected_words_for_counts):
            doc_years, doc_counts = get_yearly_counts_for_word(index=index, word=word)
            bar_fig.add_trace(go.Bar(
                x=doc_years,
                y=doc_counts,
                name=word,
                marker_color=color_cycle[i % len(color_cycle)],
                opacity=0.85
            ))
        bar_fig.update_layout(
            barmode="group",
            title="Document Frequency Over Time",
            xaxis_title="Year",
            yaxis_title="Document Count",
            xaxis=dict(
                tickmode='linear',
                dtick=1,
                tickformat='d'
            ),
            bargap=0.2
        )
        st.plotly_chart(bar_fig, use_container_width=True)

    st.markdown("#### π Inspect Documents for Word-Year Pairs")
    selected_year = st.selectbox(
        "Select year",
        options=time_labels,  # Use the list of available time labels (years)
        index=0,  # Default to the first year in the list
        key="inspect_year_selectbox"
    )

    # Gather the documents for every (word, selected year) pair, tagging each
    # document with the word that matched it.
    collected_docs_raw = []
    for word in selected_words_for_counts:
        docs_for_word_year = get_all_documents_for_word_year(
            index=index,
            docs_file_path=docs_path,
            word=word,
            year=selected_year
        )
        for doc in docs_for_word_year:
            doc["__word__"] = word
        collected_docs_raw.extend(docs_for_word_year)

    if collected_docs_raw:
        deduplicated_docs = deduplicate_docs(collected_docs_raw)
        st.session_state.collected_deduplicated_docs = deduplicated_docs
        st.write(f"Found {len(collected_docs_raw)} matching documents, {len(st.session_state.collected_deduplicated_docs)} after deduplication.")

        doc_cards = []
        for doc in deduplicated_docs:
            word = doc["__word__"]
            full_text = html.escape(doc["text"])
            snippet_text = extract_snippet(doc["text"], word)
            # NOTE(review): the snippet is inserted as returned by
            # highlight_words; confirm that helper escapes document text.
            highlighted_snippet = highlight_words(
                snippet_text,
                query_words=selected_words_for_counts,
                lemma_to_forms=lemma_to_forms
            )
            # Metadata goes into raw HTML, so escape it like the full text.
            safe_word = html.escape(str(word))
            safe_doc_id = html.escape(str(doc['id']))
            safe_timestamp = html.escape(str(doc['timestamp']))
            doc_cards.append(f"""
<div style="margin-bottom: 14px; padding: 10px; background-color: #fffbe6; border: 1px solid #f0e6cc; border-radius: 6px;">
<div style="color: #333;"><strong>Match:</strong> {safe_word} | <strong>Doc ID:</strong> {safe_doc_id} | <strong>Timestamp:</strong> {safe_timestamp}</div>
<div style="margin-top: 4px; color: #444;"><em>Snippet:</em> {highlighted_snippet}</div>
<details style="margin-top: 4px;">
<summary style="cursor: pointer; color: #007acc;">Show full document</summary>
<pre style="white-space: pre-wrap; color: #111; background-color: #fffef5; padding: 8px; border: 1px solid #f0e6cc; border-radius: 4px;">{full_text}</pre>
</details>
</div>
""")
        html_blocks = "".join(doc_cards)

        # The container grows with the number of documents, up to a cap;
        # beyond the cap the component scrolls.
        min_height = 120
        max_height = 700
        per_doc_height = 130
        dynamic_height = min_height + per_doc_height * max(len(deduplicated_docs) - 1, 0)
        container_height = min(dynamic_height, max_height)
        scrollable_html = f"""
<div style="overflow-y: auto; padding: 10px;
border: 1px solid #f0e6cc; border-radius: 6px;
background-color: #fffbe6; color: #222;
margin-bottom: 0;">
{html_blocks}
</div>
"""
        components.html(scrollable_html, height=container_height, scrolling=True)
    else:
        # Drop documents from an earlier selection so a later summary cannot
        # be built from results that are no longer shown.
        st.session_state.collected_deduplicated_docs = []
        st.warning("No documents found for the selected words and year.")
else:
    # No words selected: nothing is being inspected, so nothing to summarize.
    st.session_state.collected_deduplicated_docs = []
# ==============================================================================
# 4. CHAT ASSISTANT (Summary & Follow-up)
# ==============================================================================
st.markdown("---")
st.markdown("## 4οΈβ£ π¬ Chat Assistant")
st.info("Generate summaries from the inspected documents and ask follow-up questions.")

# Conversation state that must survive Streamlit reruns.
if "summary" not in st.session_state:
    st.session_state.summary = None
if "context_for_followup" not in st.session_state:
    st.session_state.context_for_followup = ""
if "followup_history" not in st.session_state:
    st.session_state.followup_history = []

# MMR K selection
st.markdown("**Max documents for summarization (k):**")
st.markdown(f"The suggested maximum number of documents for summarization (k) based on the average document length and the selected LLM is **{suggested_max_k}**.")
mmr_k = st.slider(
    "Select the maximum number of documents (k) for MMR (Maximum Marginal Relevance) selection for summarization.",
    min_value=1,
    max_value=20,  # Set a reasonable max for k, can be adjusted
    # Suggested value as default, kept inside the slider's [1, 20] range.
    value=max(1, min(suggested_max_k, 20)),
    help="This value determines how many relevant and diverse documents will be selected for summarization."
)

if st.button("π Summarize These Documents"):
    docs_to_summarize = st.session_state.get("collected_deduplicated_docs")
    # `selected_words_for_counts` and `selected_year` are only assigned by the
    # retrieval section when words are selected. `selected_words` is checked
    # first so those names are never read while undefined.
    if docs_to_summarize and selected_words and selected_words_for_counts:
        st.session_state.summary = None
        st.session_state.context_for_followup = ""
        st.session_state.followup_history = []
        with st.spinner("Selecting and summarizing documents..."):
            summary, mmr_docs = summarize_multiword_docs(
                selected_words_for_counts,
                selected_year,
                docs_to_summarize,
                llm,
                k=mmr_k
            )
        st.session_state.summary = summary
        # Remember what was summarized, so the heading below still matches the
        # summary after the user changes the word or year selection.
        st.session_state.summary_words = list(selected_words_for_counts)
        st.session_state.summary_year = selected_year
        st.session_state.context_for_followup = "\n".join(
            f"Document {i+1}:\n{doc.page_content.strip()}" for i, doc in enumerate(mmr_docs)
        )
        # The first two history entries are the implicit summary request and
        # the summary itself; the chat view below skips them.
        st.session_state.followup_history.append(
            {"role": "user", "content": f"Please summarize the context of the words '{', '.join(selected_words_for_counts)}' in {selected_year} based on the provided documents."}
        )
        st.session_state.followup_history.append(
            {"role": "assistant", "content": st.session_state.summary}
        )
        # This literal was split across two lines by a mis-encoded leading
        # character, leaving an unterminated string; restored as one line.
        st.success(f"✅ Summary generated from {len(mmr_docs)} MMR-selected documents.")
    else:
        st.warning("β οΈ No documents collected to summarize. Please inspect some documents first.")

if st.session_state.summary:
    summary_words = st.session_state.get("summary_words", [])
    summary_year = st.session_state.get("summary_year")
    st.markdown(f"**Summary for words `{', '.join(summary_words)}` in `{summary_year}`:**")
    st.write(st.session_state.summary)

    if st.checkbox("π¬ Ask follow-up questions about this summary", key="enable_followup"):
        with st.expander("View the documents used for this conversation"):
            st.text_area("Context Documents", st.session_state.context_for_followup, height=200)
        st.info("Ask a question based on the summary and the documents above.")

        # Replay earlier follow-up turns, skipping the summary request/answer pair.
        for msg in st.session_state.followup_history[2:]:
            with st.chat_message(msg["role"], avatar="π§" if msg["role"] == "user" else "π€"):
                st.markdown(msg["content"])

        if user_query := st.chat_input("Ask a follow-up question..."):
            with st.chat_message("user", avatar="π§"):
                st.markdown(user_query)
            st.session_state.followup_history.append({"role": "user", "content": user_query})
            with st.spinner("Thinking..."):
                followup_response = ask_multiturn_followup(
                    history=st.session_state.followup_history,
                    question=user_query,
                    llm=llm,
                    context_texts=st.session_state.context_for_followup
                )
            st.session_state.followup_history.append({"role": "assistant", "content": followup_response})
            if followup_response.startswith("[Error"):
                st.error(followup_response)
            else:
                with st.chat_message("assistant", avatar="π€"):
                    st.markdown(followup_response)
            # Rerun so the new turn is rendered from the stored history.
            st.rerun()