Spaces:
Build error
Build error
Merge master into main, resolved conflicts and updated LFS tracking
Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +5 -0
- .huggingface.yaml +3 -0
- LICENSE +21 -0
- app/ui_updated.py +450 -0
- assets/Logo_light.png +3 -0
- backend/__init__.py +81 -0
- backend/datasets/_preprocess.py +447 -0
- backend/datasets/data/download.py +32 -0
- backend/datasets/data/file_utils.py +39 -0
- backend/datasets/dynamic_dataset.py +90 -0
- backend/datasets/preprocess.py +362 -0
- backend/datasets/utils/_utils.py +37 -0
- backend/datasets/utils/logger.py +29 -0
- backend/evaluation/CoherenceModel_ttc.py +862 -0
- backend/evaluation/eval.py +179 -0
- backend/inference/doc_retriever.py +219 -0
- backend/inference/indexing_utils.py +146 -0
- backend/inference/peak_detector.py +18 -0
- backend/inference/process_beta.py +33 -0
- backend/inference/word_selector.py +102 -0
- backend/llm/custom_gemini.py +28 -0
- backend/llm/custom_mistral.py +27 -0
- backend/llm/llm_router.py +73 -0
- backend/llm_utils/label_generator.py +72 -0
- backend/llm_utils/summarizer.py +192 -0
- backend/llm_utils/token_utils.py +167 -0
- backend/models/CFDTM/CFDTM.py +127 -0
- backend/models/CFDTM/ETC.py +62 -0
- backend/models/CFDTM/Encoder.py +40 -0
- backend/models/CFDTM/UWE.py +48 -0
- backend/models/CFDTM/__init__.py +0 -0
- backend/models/CFDTM/__pycache__/CFDTM.cpython-39.pyc +0 -0
- backend/models/CFDTM/__pycache__/ETC.cpython-39.pyc +0 -0
- backend/models/CFDTM/__pycache__/Encoder.cpython-39.pyc +0 -0
- backend/models/CFDTM/__pycache__/UWE.cpython-39.pyc +0 -0
- backend/models/CFDTM/__pycache__/__init__.cpython-39.pyc +0 -0
- backend/models/DBERTopic_trainer.py +99 -0
- backend/models/DETM.py +259 -0
- backend/models/DTM_trainer.py +148 -0
- backend/models/dynamic_trainer.py +177 -0
- data/ACL_Anthology/CFDTM/beta.npy +3 -0
- data/ACL_Anthology/DETM/beta.npy +3 -0
- data/ACL_Anthology/DTM/beta.npy +3 -0
- data/ACL_Anthology/DTM/topic_label_cache.json +3 -0
- data/ACL_Anthology/docs.jsonl +3 -0
- data/ACL_Anthology/inverted_index.json +3 -0
- data/ACL_Anthology/processed/lemma_to_forms.json +3 -0
- data/ACL_Anthology/processed/length_stats.json +3 -0
- data/ACL_Anthology/processed/time2id.txt +18 -0
- data/ACL_Anthology/processed/vocab.txt +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data/**/*.npy filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
data/**/*.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
data/**/*.json filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/*.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
data/**/*.npz filter=lfs diff=lfs merge=lfs -text
|
.huggingface.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# .huggingface.yaml
|
| 2 |
+
sdk: streamlit # or gradio
|
| 3 |
+
app_file: ./app/ui.py
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Suman Adhya
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
app/ui_updated.py
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import plotly.graph_objects as go
|
| 3 |
+
import plotly.colors as pc
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
import base64
|
| 7 |
+
import streamlit.components.v1 as components
|
| 8 |
+
import html
|
| 9 |
+
|
| 10 |
+
# Absolute path to the repo root (assumes this file lives in <repo>/app)
|
| 11 |
+
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 12 |
+
sys.path.append(REPO_ROOT)
|
| 13 |
+
ASSETS_DIR = os.path.join(REPO_ROOT, 'assets')
|
| 14 |
+
DATA_DIR = os.path.join(REPO_ROOT, 'data')
|
| 15 |
+
|
| 16 |
+
# Import functions from the backend
|
| 17 |
+
from backend.inference.process_beta import (
|
| 18 |
+
load_beta_matrix,
|
| 19 |
+
get_top_words_over_time,
|
| 20 |
+
load_time_labels
|
| 21 |
+
)
|
| 22 |
+
from backend.inference.word_selector import get_interesting_words, get_word_trend
|
| 23 |
+
from backend.inference.indexing_utils import load_index
|
| 24 |
+
from backend.inference.doc_retriever import (
|
| 25 |
+
load_length_stats,
|
| 26 |
+
get_yearly_counts_for_word,
|
| 27 |
+
deduplicate_docs,
|
| 28 |
+
get_all_documents_for_word_year,
|
| 29 |
+
highlight_words,
|
| 30 |
+
extract_snippet
|
| 31 |
+
)
|
| 32 |
+
from backend.llm_utils.summarizer import summarize_multiword_docs, ask_multiturn_followup
|
| 33 |
+
from backend.llm_utils.label_generator import get_topic_labels
|
| 34 |
+
from backend.llm.llm_router import get_llm, list_supported_models
|
| 35 |
+
from backend.llm_utils.token_utils import estimate_k_max_from_word_stats
|
| 36 |
+
|
| 37 |
+
def get_base64_image(image_path):
    """Return the contents of the file at *image_path* as a base64 text string.

    The file is read in binary mode and encoded with ``base64.b64encode``;
    the resulting bytes are decoded to ``str`` so the value can be embedded
    directly in a ``data:`` URI.

    Args:
        image_path: Path of the file to read.

    Returns:
        The base64 encoding of the file's bytes, as ``str`` (empty string
        for an empty file).

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(image_path, "rb") as img_file:
        raw_bytes = img_file.read()
    # base64 output is pure ASCII, so the default decode is lossless.
    return base64.b64encode(raw_bytes).decode()
|
| 40 |
+
|
| 41 |
+
# --- Page Configuration ---
|
| 42 |
+
st.set_page_config(
|
| 43 |
+
page_title="DTECT",
|
| 44 |
+
page_icon="🔍",
|
| 45 |
+
layout="wide"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# Sidebar branding and repo link
|
| 49 |
+
st.sidebar.markdown(
|
| 50 |
+
"""
|
| 51 |
+
<div style="text-align: center;">
|
| 52 |
+
<a href="https://github.com/dinb-ai/DTECT" target="_blank">
|
| 53 |
+
<img src="data:image/png;base64,{}" width="180" style="margin-bottom: 18px;">
|
| 54 |
+
</a>
|
| 55 |
+
<hr style="margin-bottom: 0;">
|
| 56 |
+
</div>
|
| 57 |
+
""".format(get_base64_image(os.path.join(ASSETS_DIR, 'Logo_light.png'))),
|
| 58 |
+
unsafe_allow_html=True
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# 1. Sidebar: Model and Dataset Selection
|
| 62 |
+
st.sidebar.title("Configuration")
|
| 63 |
+
|
| 64 |
+
AVAILABLE_MODELS = ["DTM", "DETM", "CFDTM"]
|
| 65 |
+
ENV_VAR_MAP = {
|
| 66 |
+
"OpenAI": "OPENAI_API_KEY",
|
| 67 |
+
"Anthropic": "ANTHROPIC_API_KEY",
|
| 68 |
+
"Gemini": "GEMINI_API_KEY",
|
| 69 |
+
"Mistral": "MISTRAL_API_KEY"
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
def list_datasets(data_dir):
    """Return the sorted names of the immediate subdirectories of *data_dir*.

    Only entries for which ``os.path.isdir`` is true are kept; regular
    files directly inside *data_dir* are ignored.

    Args:
        data_dir: Directory to scan.

    Returns:
        A list of subdirectory names (not full paths) in ascending order.

    Raises:
        OSError: If *data_dir* cannot be listed.
    """
    subdirs = []
    for entry_name in os.listdir(data_dir):
        if os.path.isdir(os.path.join(data_dir, entry_name)):
            subdirs.append(entry_name)
    subdirs.sort()
    return subdirs
|
| 77 |
+
|
| 78 |
+
with st.sidebar.expander("Select Dataset & Topic Model", expanded=True):
|
| 79 |
+
datasets = list_datasets(DATA_DIR)
|
| 80 |
+
selected_dataset = st.selectbox("Dataset", datasets, help="Choose an available dataset.")
|
| 81 |
+
selected_model = st.selectbox("Model", AVAILABLE_MODELS, help="Select topic model architecture.")
|
| 82 |
+
|
| 83 |
+
# Resolve paths
|
| 84 |
+
dataset_path = os.path.join(DATA_DIR, selected_dataset)
|
| 85 |
+
model_path = os.path.join(dataset_path, selected_model)
|
| 86 |
+
docs_path = os.path.join(dataset_path, "docs.jsonl")
|
| 87 |
+
vocab_path = os.path.join(dataset_path, "processed/vocab.txt")
|
| 88 |
+
time2id_path = os.path.join(dataset_path, "processed/time2id.txt")
|
| 89 |
+
index_path = os.path.join(dataset_path, "inverted_index.json")
|
| 90 |
+
beta_path = os.path.join(model_path, "beta.npy")
|
| 91 |
+
label_cache_path = os.path.join(model_path, "topic_label_cache.json")
|
| 92 |
+
length_stats_path = os.path.join(dataset_path, "processed/length_stats.json")
|
| 93 |
+
lemma_map_path = os.path.join(dataset_path, "processed/lemma_to_forms.json")
|
| 94 |
+
|
| 95 |
+
with st.sidebar.expander("LLM Settings", expanded=True):
|
| 96 |
+
provider = st.selectbox("LLM Provider", options=list(ENV_VAR_MAP.keys()), help="Choose the LLM backend.")
|
| 97 |
+
available_models = list_supported_models(provider)
|
| 98 |
+
model = st.selectbox("LLM Model", options=available_models)
|
| 99 |
+
env_var = ENV_VAR_MAP[provider]
|
| 100 |
+
api_key = os.getenv(env_var)
|
| 101 |
+
|
| 102 |
+
if "llm_configured" not in st.session_state:
|
| 103 |
+
st.session_state.llm_configured = False
|
| 104 |
+
|
| 105 |
+
if api_key:
|
| 106 |
+
st.session_state.llm_configured = True
|
| 107 |
+
else:
|
| 108 |
+
st.session_state.llm_configured = False
|
| 109 |
+
with st.form(key="api_key_form"):
|
| 110 |
+
entered_key = st.text_input(f"Enter your {provider} API Key", type="password")
|
| 111 |
+
submitted = st.form_submit_button("Submit and Confirm")
|
| 112 |
+
if submitted:
|
| 113 |
+
if entered_key:
|
| 114 |
+
os.environ[env_var] = entered_key
|
| 115 |
+
api_key = entered_key
|
| 116 |
+
st.session_state.llm_configured = True
|
| 117 |
+
st.rerun()
|
| 118 |
+
else:
|
| 119 |
+
st.warning("Please enter a key.")
|
| 120 |
+
|
| 121 |
+
if not st.session_state.llm_configured:
|
| 122 |
+
st.warning("Please configure your LLM settings in the sidebar.")
|
| 123 |
+
st.stop()
|
| 124 |
+
|
| 125 |
+
if api_key and not st.session_state.llm_configured:
|
| 126 |
+
st.session_state.llm_configured = True
|
| 127 |
+
|
| 128 |
+
if not api_key:
|
| 129 |
+
st.session_state.llm_configured = False
|
| 130 |
+
|
| 131 |
+
if not st.session_state.llm_configured:
|
| 132 |
+
st.warning("Please configure your LLM settings in the sidebar.")
|
| 133 |
+
st.stop()
|
| 134 |
+
|
| 135 |
+
# Initialize LLM with the provided key
|
| 136 |
+
llm = get_llm(provider=provider, model=model, api_key=api_key)
|
| 137 |
+
|
| 138 |
+
# 3. Load Data
|
| 139 |
+
@st.cache_resource
def load_resources(beta_path, vocab_path, docs_path, index_path, time2id_path, length_stats_path, lemma_map_path):
    """Load all on-disk artefacts the page needs and return them as one tuple.

    The loaders are called in a fixed order because ``load_index`` needs the
    ``vocab`` produced by ``load_beta_matrix``.

    NOTE(review): the function is wrapped in ``st.cache_resource``; per the
    Streamlit documentation this should reuse the returned objects across
    reruns for identical arguments -- confirm against the pinned version.

    Args:
        beta_path: Passed to ``load_beta_matrix`` as the first argument.
        vocab_path: Passed to ``load_beta_matrix`` as the second argument.
        docs_path: Passed to ``load_index`` as ``docs_file_path``.
        index_path: Passed to ``load_index`` as ``index_path``.
        time2id_path: Passed to ``load_time_labels``.
        length_stats_path: Passed to ``load_length_stats``.
        lemma_map_path: Passed to ``load_index`` as ``lemma_map_path``.

    Returns:
        The tuple ``(beta, vocab, index, docs, lemma_to_forms, time_labels,
        length_stats)``.
    """
    beta, vocab = load_beta_matrix(beta_path, vocab_path)
    index, docs, lemma_to_forms = load_index(
        docs_file_path=docs_path,
        vocab=vocab,
        index_path=index_path,
        lemma_map_path=lemma_map_path,
    )
    time_labels = load_time_labels(time2id_path)
    length_stats = load_length_stats(length_stats_path)
    return beta, vocab, index, docs, lemma_to_forms, time_labels, length_stats
|
| 146 |
+
|
| 147 |
+
# --- Main Title and Paper-aligned Intro ---
|
| 148 |
+
st.markdown("""# 🔍 DTECT: Dynamic Topic Explorer & Context Tracker""")
|
| 149 |
+
|
| 150 |
+
# --- Load resources ---
|
| 151 |
+
try:
|
| 152 |
+
beta, vocab, index, docs, lemma_to_forms, time_labels, length_stats = load_resources(
|
| 153 |
+
beta_path,
|
| 154 |
+
vocab_path,
|
| 155 |
+
docs_path,
|
| 156 |
+
index_path,
|
| 157 |
+
time2id_path,
|
| 158 |
+
length_stats_path,
|
| 159 |
+
lemma_map_path
|
| 160 |
+
)
|
| 161 |
+
except FileNotFoundError as e:
|
| 162 |
+
st.error(f"Missing required file: {e}")
|
| 163 |
+
st.stop()
|
| 164 |
+
except Exception as e:
|
| 165 |
+
st.error(f"Failed to load data: {str(e)}")
|
| 166 |
+
st.stop()
|
| 167 |
+
|
| 168 |
+
timestamps = list(range(len(time_labels)))
|
| 169 |
+
num_topics = beta.shape[1]
|
| 170 |
+
# Estimate max_k based on document length stats and selected LLM
|
| 171 |
+
suggested_max_k = estimate_k_max_from_word_stats(length_stats.get("avg_len"), model_name=model, provider=provider)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ==============================================================================
|
| 175 |
+
# 1. 🏷 TOPIC LABELING
|
| 176 |
+
# ==============================================================================
|
| 177 |
+
st.markdown("## 1️⃣ 🏷️ Topic Labeling")
|
| 178 |
+
st.info("Topics are automatically labeled using LLMs by analyzing their temporal word distributions.")
|
| 179 |
+
|
| 180 |
+
topic_labels = get_topic_labels(beta, vocab, time_labels, llm, label_cache_path)
|
| 181 |
+
topic_options = list(topic_labels.values())
|
| 182 |
+
selected_topic_label = st.selectbox("Select a Topic", topic_options, help="LLM-generated topic label")
|
| 183 |
+
label_to_topic = {v: k for k, v in topic_labels.items()}
|
| 184 |
+
selected_topic = label_to_topic[selected_topic_label]
|
| 185 |
+
|
| 186 |
+
# ==============================================================================
|
| 187 |
+
# 2. 💡 INFORMATIVE WORD DETECTION & 📊 TREND VISUALIZATION
|
| 188 |
+
# ==============================================================================
|
| 189 |
+
st.markdown("---")
|
| 190 |
+
st.markdown("## 2️⃣ 💡 Informative Word Detection & 📊 Trend Visualization")
|
| 191 |
+
st.info("Explore top/interesting words for each topic, and visualize their trends over time.")
|
| 192 |
+
|
| 193 |
+
top_n_words = st.slider("Number of Top Words per Topic", min_value=5, max_value=500, value=10)
|
| 194 |
+
top_words = get_top_words_over_time(
|
| 195 |
+
beta=beta,
|
| 196 |
+
vocab=vocab,
|
| 197 |
+
topic_id=selected_topic,
|
| 198 |
+
top_n=top_n_words
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
st.write(f"### Top {top_n_words} Words for Topic '{selected_topic_label}' (Ranked):")
|
| 202 |
+
scrollable_top_words = "<div style='max-height: 200px; overflow-y: auto; padding: 0 10px;'>"
|
| 203 |
+
words_per_col = (top_n_words + 3) // 4
|
| 204 |
+
columns = [top_words[i:i+words_per_col] for i in range(0, len(top_words), words_per_col)]
|
| 205 |
+
scrollable_top_words += "<div style='display: flex; gap: 20px;'>"
|
| 206 |
+
word_rank = 1
|
| 207 |
+
for col in columns:
|
| 208 |
+
scrollable_top_words += "<div style='flex: 1;'>"
|
| 209 |
+
for word in col:
|
| 210 |
+
scrollable_top_words += f"<div style='margin-bottom: 4px;'>{word_rank}. {word}</div>"
|
| 211 |
+
word_rank += 1
|
| 212 |
+
scrollable_top_words += "</div>"
|
| 213 |
+
scrollable_top_words += "</div></div>"
|
| 214 |
+
st.markdown(scrollable_top_words, unsafe_allow_html=True)
|
| 215 |
+
|
| 216 |
+
st.markdown("<div style='margin-top: 18px;'></div>", unsafe_allow_html=True)
|
| 217 |
+
|
| 218 |
+
if st.button("💡 Suggest Informative Words", key="suggest_topic_words"):
|
| 219 |
+
top_words = get_top_words_over_time(
|
| 220 |
+
beta=beta,
|
| 221 |
+
vocab=vocab,
|
| 222 |
+
topic_id=selected_topic,
|
| 223 |
+
top_n=top_n_words
|
| 224 |
+
)
|
| 225 |
+
interesting_words = get_interesting_words(beta, vocab, topic_id=selected_topic, restrict_to=top_words)
|
| 226 |
+
st.session_state.interesting_words = interesting_words
|
| 227 |
+
st.session_state.selected_words = interesting_words[:15] # pre-fill multiselect
|
| 228 |
+
styled_words = " ".join([
|
| 229 |
+
f"<span style='background-color:#e0f7fa; color:#004d40; font-weight:500; padding:4px 8px; margin:4px; border-radius:8px; display:inline-block;'>{w}</span>"
|
| 230 |
+
for w in interesting_words
|
| 231 |
+
])
|
| 232 |
+
st.markdown(
|
| 233 |
+
f"**Top Informative Words from Topic '{selected_topic_label}':**<br>{styled_words}",
|
| 234 |
+
unsafe_allow_html=True
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
st.markdown("#### 📈 Plot Word Trends Over Time")
|
| 238 |
+
all_word_options = vocab
|
| 239 |
+
interesting_words = st.session_state.get("interesting_words", [])
|
| 240 |
+
|
| 241 |
+
if "selected_words" not in st.session_state:
|
| 242 |
+
st.session_state.selected_words = interesting_words[:15] # initial default
|
| 243 |
+
|
| 244 |
+
selected_words = st.multiselect(
|
| 245 |
+
"Select words to visualize trends",
|
| 246 |
+
options=all_word_options,
|
| 247 |
+
default=st.session_state.selected_words,
|
| 248 |
+
key="selected_words"
|
| 249 |
+
)
|
| 250 |
+
if selected_words:
|
| 251 |
+
fig = go.Figure()
|
| 252 |
+
color_cycle = pc.qualitative.Plotly
|
| 253 |
+
for i, word in enumerate(selected_words):
|
| 254 |
+
trend = get_word_trend(beta, vocab, word, topic_id=selected_topic)
|
| 255 |
+
color = color_cycle[i % len(color_cycle)]
|
| 256 |
+
fig.add_trace(go.Scatter(
|
| 257 |
+
x=time_labels,
|
| 258 |
+
y=trend,
|
| 259 |
+
name=word,
|
| 260 |
+
line=dict(color=color),
|
| 261 |
+
legendgroup=word,
|
| 262 |
+
showlegend=True
|
| 263 |
+
))
|
| 264 |
+
fig.update_layout(title="", xaxis_title="Year", yaxis_title="Importance")
|
| 265 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 266 |
+
|
| 267 |
+
# ==============================================================================
|
| 268 |
+
# 3. 🔍 DOCUMENT RETRIEVAL & 📃 SUMMARIZATION
|
| 269 |
+
# ==============================================================================
|
| 270 |
+
st.markdown("---")
|
| 271 |
+
st.markdown("## 3️⃣ 🔍 Document Retrieval & 📃 Summarization")
|
| 272 |
+
st.info("Retrieve and summarize documents matching selected words and years.")
|
| 273 |
+
|
| 274 |
+
if selected_words:
|
| 275 |
+
st.markdown("#### 📊 Document Frequency Over Time")
|
| 276 |
+
selected_words_for_counts = st.multiselect(
|
| 277 |
+
"Select word(s) to show document frequencies over time",
|
| 278 |
+
options=selected_words,
|
| 279 |
+
default=selected_words[:3],
|
| 280 |
+
key="word_counts_multiselect"
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
if selected_words_for_counts:
|
| 284 |
+
color_cycle = pc.qualitative.Set2
|
| 285 |
+
bar_fig = go.Figure()
|
| 286 |
+
for i, word in enumerate(selected_words_for_counts):
|
| 287 |
+
doc_years, doc_counts = get_yearly_counts_for_word(index=index, word=word)
|
| 288 |
+
bar_fig.add_trace(go.Bar(
|
| 289 |
+
x=doc_years,
|
| 290 |
+
y=doc_counts,
|
| 291 |
+
name=word,
|
| 292 |
+
marker_color=color_cycle[i % len(color_cycle)],
|
| 293 |
+
opacity=0.85
|
| 294 |
+
))
|
| 295 |
+
bar_fig.update_layout(
|
| 296 |
+
barmode="group",
|
| 297 |
+
title="Document Frequency Over Time",
|
| 298 |
+
xaxis_title="Year",
|
| 299 |
+
yaxis_title="Document Count",
|
| 300 |
+
xaxis=dict(
|
| 301 |
+
tickmode='linear',
|
| 302 |
+
dtick=1,
|
| 303 |
+
tickformat='d'
|
| 304 |
+
),
|
| 305 |
+
bargap=0.2
|
| 306 |
+
)
|
| 307 |
+
st.plotly_chart(bar_fig, use_container_width=True)
|
| 308 |
+
|
| 309 |
+
st.markdown("#### 📄 Inspect Documents for Word-Year Pairs")
|
| 310 |
+
# selected_year = st.slider("Select year", min_value=int(time_labels[0]), max_value=int(time_labels[-1]), key="inspect_year_slider")
|
| 311 |
+
selected_year = st.selectbox(
|
| 312 |
+
"Select year",
|
| 313 |
+
options=time_labels, # Use the list of available time labels (years)
|
| 314 |
+
index=0, # Default to the first year in the list
|
| 315 |
+
key="inspect_year_selectbox"
|
| 316 |
+
)
|
| 317 |
+
collected_docs_raw = []
|
| 318 |
+
for word in selected_words_for_counts:
|
| 319 |
+
docs_for_word_year = get_all_documents_for_word_year(
|
| 320 |
+
index=index,
|
| 321 |
+
docs_file_path=docs_path,
|
| 322 |
+
word=word,
|
| 323 |
+
year=selected_year
|
| 324 |
+
)
|
| 325 |
+
for doc in docs_for_word_year:
|
| 326 |
+
doc["__word__"] = word
|
| 327 |
+
collected_docs_raw.extend(docs_for_word_year)
|
| 328 |
+
|
| 329 |
+
if collected_docs_raw:
|
| 330 |
+
st.session_state.collected_deduplicated_docs = deduplicate_docs(collected_docs_raw)
|
| 331 |
+
st.write(f"Found {len(collected_docs_raw)} matching documents, {len(st.session_state.collected_deduplicated_docs)} after deduplication.")
|
| 332 |
+
|
| 333 |
+
html_blocks = ""
|
| 334 |
+
for doc in st.session_state.collected_deduplicated_docs:
|
| 335 |
+
word = doc["__word__"]
|
| 336 |
+
full_text = html.escape(doc["text"])
|
| 337 |
+
snippet_text = extract_snippet(doc["text"], word)
|
| 338 |
+
highlighted_snippet = highlight_words(
|
| 339 |
+
snippet_text,
|
| 340 |
+
query_words=selected_words_for_counts,
|
| 341 |
+
lemma_to_forms=lemma_to_forms
|
| 342 |
+
)
|
| 343 |
+
html_blocks += f"""
|
| 344 |
+
<div style="margin-bottom: 14px; padding: 10px; background-color: #fffbe6; border: 1px solid #f0e6cc; border-radius: 6px;">
|
| 345 |
+
<div style="color: #333;"><strong>Match:</strong> {word} | <strong>Doc ID:</strong> {doc['id']} | <strong>Timestamp:</strong> {doc['timestamp']}</div>
|
| 346 |
+
<div style="margin-top: 4px; color: #444;"><em>Snippet:</em> {highlighted_snippet}</div>
|
| 347 |
+
<details style="margin-top: 4px;">
|
| 348 |
+
<summary style="cursor: pointer; color: #007acc;">Show full document</summary>
|
| 349 |
+
<pre style="white-space: pre-wrap; color: #111; background-color: #fffef5; padding: 8px; border: 1px solid #f0e6cc; border-radius: 4px;">{full_text}</pre>
|
| 350 |
+
</details>
|
| 351 |
+
</div>
|
| 352 |
+
"""
|
| 353 |
+
min_height = 120
|
| 354 |
+
max_height = 700
|
| 355 |
+
per_doc_height = 130
|
| 356 |
+
dynamic_height = min_height + per_doc_height * max(len(st.session_state.collected_deduplicated_docs) - 1, 0)
|
| 357 |
+
container_height = min(dynamic_height, max_height)
|
| 358 |
+
scrollable_html = f"""
|
| 359 |
+
<div style="overflow-y: auto; padding: 10px;
|
| 360 |
+
border: 1px solid #f0e6cc; border-radius: 6px;
|
| 361 |
+
background-color: #fffbe6; color: #222;
|
| 362 |
+
margin-bottom: 0;">
|
| 363 |
+
{html_blocks}
|
| 364 |
+
</div>
|
| 365 |
+
"""
|
| 366 |
+
components.html(scrollable_html, height=container_height, scrolling=True)
|
| 367 |
+
else:
|
| 368 |
+
st.warning("No documents found for the selected words and year.")
|
| 369 |
+
|
| 370 |
+
# ==============================================================================
|
| 371 |
+
# 4. 💬 CHAT ASSISTANT (Summary & Follow-up)
|
| 372 |
+
# ==============================================================================
|
| 373 |
+
st.markdown("---")
|
| 374 |
+
st.markdown("## 4️⃣ 💬 Chat Assistant")
|
| 375 |
+
st.info("Generate summaries from the inspected documents and ask follow-up questions.")
|
| 376 |
+
|
| 377 |
+
if "summary" not in st.session_state:
|
| 378 |
+
st.session_state.summary = None
|
| 379 |
+
if "context_for_followup" not in st.session_state:
|
| 380 |
+
st.session_state.context_for_followup = ""
|
| 381 |
+
if "followup_history" not in st.session_state:
|
| 382 |
+
st.session_state.followup_history = []
|
| 383 |
+
|
| 384 |
+
# MMR K selection
|
| 385 |
+
st.markdown(f"**Max documents for summarization (k):**")
|
| 386 |
+
st.markdown(f"The suggested maximum number of documents for summarization (k) based on the average document length and the selected LLM is **{suggested_max_k}**.")
|
| 387 |
+
mmr_k = st.slider(
|
| 388 |
+
"Select the maximum number of documents (k) for MMR (Maximum Marginal Relevance) selection for summarization.",
|
| 389 |
+
min_value=1,
|
| 390 |
+
max_value=20, # Set a reasonable max for k, can be adjusted
|
| 391 |
+
value=min(suggested_max_k, 20), # Use suggested_max_k as default, capped at 20
|
| 392 |
+
help="This value determines how many relevant and diverse documents will be selected for summarization."
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
if st.button("📃 Summarize These Documents"):
|
| 396 |
+
if st.session_state.get("collected_deduplicated_docs"):
|
| 397 |
+
st.session_state.summary = None
|
| 398 |
+
st.session_state.context_for_followup = ""
|
| 399 |
+
st.session_state.followup_history = []
|
| 400 |
+
with st.spinner("Selecting and summarizing documents..."):
|
| 401 |
+
summary, mmr_docs = summarize_multiword_docs(
|
| 402 |
+
selected_words_for_counts,
|
| 403 |
+
selected_year,
|
| 404 |
+
st.session_state.collected_deduplicated_docs,
|
| 405 |
+
llm,
|
| 406 |
+
k=mmr_k
|
| 407 |
+
)
|
| 408 |
+
st.session_state.summary = summary
|
| 409 |
+
st.session_state.context_for_followup = "\n".join(
|
| 410 |
+
f"Document {i+1}:\n{doc.page_content.strip()}" for i, doc in enumerate(mmr_docs)
|
| 411 |
+
)
|
| 412 |
+
st.session_state.followup_history.append(
|
| 413 |
+
{"role": "user", "content": f"Please summarize the context of the words '{', '.join(selected_words_for_counts)}' in {selected_year} based on the provided documents."}
|
| 414 |
+
)
|
| 415 |
+
st.session_state.followup_history.append(
|
| 416 |
+
{"role": "assistant", "content": st.session_state.summary}
|
| 417 |
+
)
|
| 418 |
+
st.success(f"✅ Summary generated from {len(mmr_docs)} MMR-selected documents.")
|
| 419 |
+
else:
|
| 420 |
+
st.warning("⚠️ No documents collected to summarize. Please inspect some documents first.")
|
| 421 |
+
|
| 422 |
+
if st.session_state.summary:
|
| 423 |
+
st.markdown(f"**Summary for words `{', '.join(selected_words_for_counts)}` in `{selected_year}`:**")
|
| 424 |
+
st.write(st.session_state.summary)
|
| 425 |
+
|
| 426 |
+
if st.checkbox("💬 Ask follow-up questions about this summary", key="enable_followup"):
|
| 427 |
+
with st.expander("View the documents used for this conversation"):
|
| 428 |
+
st.text_area("Context Documents", st.session_state.context_for_followup, height=200)
|
| 429 |
+
st.info("Ask a question based on the summary and the documents above.")
|
| 430 |
+
for msg in st.session_state.followup_history[2:]:
|
| 431 |
+
with st.chat_message(msg["role"], avatar="🧑" if msg["role"] == "user" else "🤖"):
|
| 432 |
+
st.markdown(msg["content"])
|
| 433 |
+
if user_query := st.chat_input("Ask a follow-up question..."):
|
| 434 |
+
with st.chat_message("user", avatar="🧑"):
|
| 435 |
+
st.markdown(user_query)
|
| 436 |
+
st.session_state.followup_history.append({"role": "user", "content": user_query})
|
| 437 |
+
with st.spinner("Thinking..."):
|
| 438 |
+
followup_response = ask_multiturn_followup(
|
| 439 |
+
history=st.session_state.followup_history,
|
| 440 |
+
question=user_query,
|
| 441 |
+
llm=llm,
|
| 442 |
+
context_texts=st.session_state.context_for_followup
|
| 443 |
+
)
|
| 444 |
+
st.session_state.followup_history.append({"role": "assistant", "content": followup_response})
|
| 445 |
+
if followup_response.startswith("[Error"):
|
| 446 |
+
st.error(followup_response)
|
| 447 |
+
else:
|
| 448 |
+
with st.chat_message("assistant", avatar="🤖"):
|
| 449 |
+
st.markdown(followup_response)
|
| 450 |
+
st.rerun()
|
assets/Logo_light.png
ADDED
|
Git LFS Details
|
backend/__init__.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# === Inference components ===
|
| 2 |
+
from .inference.process_beta import (
|
| 3 |
+
load_beta_matrix,
|
| 4 |
+
get_top_words_at_time,
|
| 5 |
+
get_top_words_over_time,
|
| 6 |
+
load_time_labels
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
from .inference.indexing_utils import load_index
|
| 10 |
+
from .inference.word_selector import (
|
| 11 |
+
get_interesting_words,
|
| 12 |
+
get_word_trend
|
| 13 |
+
)
|
| 14 |
+
from .inference.peak_detector import detect_peaks
|
| 15 |
+
from .inference.doc_retriever import (
|
| 16 |
+
load_length_stats,
|
| 17 |
+
get_yearly_counts_for_word,
|
| 18 |
+
get_all_documents_for_word_year,
|
| 19 |
+
deduplicate_docs,
|
| 20 |
+
extract_snippet,
|
| 21 |
+
highlight,
|
| 22 |
+
get_docs_by_ids,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# === LLM components ===
|
| 26 |
+
from .llm_utils.label_generator import label_topic_temporal, get_topic_labels
|
| 27 |
+
from .llm_utils.token_utils import (
|
| 28 |
+
get_token_limit_for_model,
|
| 29 |
+
count_tokens,
|
| 30 |
+
estimate_avg_tokens_per_doc,
|
| 31 |
+
estimate_max_k,
|
| 32 |
+
estimate_max_k_fast
|
| 33 |
+
)
|
| 34 |
+
from .llm_utils.summarizer import (
|
| 35 |
+
summarize_docs,
|
| 36 |
+
summarize_multiword_docs,
|
| 37 |
+
ask_multiturn_followup
|
| 38 |
+
)
|
| 39 |
+
from .llm.llm_router import (
|
| 40 |
+
list_supported_models,
|
| 41 |
+
get_llm
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# === Dataset utilities ===
|
| 45 |
+
from .datasets import dynamic_dataset
|
| 46 |
+
from .datasets import preprocess
|
| 47 |
+
from .datasets.utils import logger, _utils
|
| 48 |
+
from .datasets.data import file_utils, download
|
| 49 |
+
|
| 50 |
+
# === Evaluation ===
|
| 51 |
+
from .evaluation.CoherenceModel_ttc import CoherenceModel_ttc
|
| 52 |
+
from .evaluation.eval import TopicQualityAssessor
|
| 53 |
+
|
| 54 |
+
# === Models ===
|
| 55 |
+
from .models.DETM import DETM
|
| 56 |
+
from .models.DTM_trainer import DTMTrainer
|
| 57 |
+
from .models.CFDTM.CFDTM import CFDTM
|
| 58 |
+
from .models.dynamic_trainer import DynamicTrainer
|
| 59 |
+
|
| 60 |
+
__all__ = [
|
| 61 |
+
# Inference
|
| 62 |
+
"load_beta_matrix", "load_time_labels", "get_top_words_at_time", "get_top_words_over_time",
|
| 63 |
+
"load_index", "get_interesting_words", "get_word_trend", "detect_peaks",
|
| 64 |
+
"load_length_stats", "get_yearly_counts_for_word", "get_all_documents_for_word_year",
|
| 65 |
+
"deduplicate_docs", "extract_snippet", "highlight", "get_docs_by_ids",
|
| 66 |
+
|
| 67 |
+
# LLM
|
| 68 |
+
"summarize_docs", "summarize_multiword_docs", "ask_multiturn_followup",
|
| 69 |
+
"get_token_limit_for_model", "list_supported_models", "get_llm",
|
| 70 |
+
"label_topic_temporal", "get_topic_labels", "count_tokens",
|
| 71 |
+
"estimate_avg_tokens_per_doc", "estimate_max_k", "estimate_max_k_fast",
|
| 72 |
+
|
| 73 |
+
# Dataset
|
| 74 |
+
"dynamic_dataset", "preprocess", "logger","_utils", "file_utils", "download",
|
| 75 |
+
|
| 76 |
+
# Evaluation
|
| 77 |
+
"CoherenceModel_ttc", "TopicQualityAssessor",
|
| 78 |
+
|
| 79 |
+
# Models
|
| 80 |
+
"DETM", "DTMTrainer", "CFDTM", "DynamicTrainer"
|
| 81 |
+
]
|
backend/datasets/_preprocess.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import string
|
| 4 |
+
import gensim.downloader
|
| 5 |
+
from collections import Counter
|
| 6 |
+
import numpy as np
|
| 7 |
+
import scipy.sparse
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from sklearn.feature_extraction.text import CountVectorizer
|
| 10 |
+
|
| 11 |
+
from backend.datasets.data import file_utils
|
| 12 |
+
from backend.datasets.utils._utils import get_stopwords_set
|
| 13 |
+
from backend.datasets.utils.logger import Logger
|
| 14 |
+
import json
|
| 15 |
+
import nltk
|
| 16 |
+
from nltk.stem import WordNetLemmatizer
|
| 17 |
+
|
| 18 |
+
logger = Logger("WARNING")

# Make sure the WordNet corpora used by WordNetLemmatizer are installed,
# quietly downloading each one the first time it is found to be missing.
for _resource_path, _package_name in (
    ('corpora/wordnet', 'wordnet'),
    ('corpora/omw-1.4', 'omw-1.4'),
):
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package_name, quiet=True)
del _resource_path, _package_name
|
| 28 |
+
|
| 29 |
+
# Pre-compiled regexes shared by the tokenizer.
# `punctuation` holds every ASCII punctuation character except the apostrophe,
# which is handled separately when cleaning text.
punct_chars = sorted(set(string.punctuation) - set("'"))
punctuation = ''.join(punct_chars)
# Matches any single punctuation character (apostrophe excluded).
replace = re.compile('[%s]' % re.escape(punctuation))
# Token made only of letters/underscores.
alpha = re.compile('^[a-zA-Z_]+$')
# Token made only of letters/underscores OR only of digits/underscores.
# The alternation is grouped so that both anchors apply to both branches;
# without the group, any token merely *starting* with a letter matched.
alpha_or_num = re.compile('^(?:[a-zA-Z_]+|[0-9_]+)$')
# Token made of any mix of letters, digits and underscores.
alphanum = re.compile('^[a-zA-Z0-9_]+$')
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Tokenizer:
    """Clean a raw text and split it into filtered (optionally lemmatized) unigrams."""

    def __init__(self,
                 stopwords="English",
                 keep_num=False,
                 keep_alphanum=False,
                 strip_html=False,
                 no_lower=False,
                 min_length=3,
                 lemmatize=True,
                 ):
        """
        Args:
            stopwords: Passed to ``get_stopwords_set`` to obtain the stopword set.
            keep_num: Keep tokens made of only numbers.
            keep_alphanum: Keep tokens mixing letters and numbers.
            strip_html: Remove HTML tags instead of turning angle brackets into parentheses.
            no_lower: Do not lowercase the text.
            min_length: Minimum token length; shorter tokens are dropped.
            lemmatize: Lemmatize surviving tokens with ``WordNetLemmatizer``.
        """
        self.stopword_set = get_stopwords_set(stopwords)

        self.keep_num = keep_num
        self.keep_alphanum = keep_alphanum
        self.strip_html = strip_html
        self.lower = not no_lower
        self.min_length = min_length

        self.lemmatize = lemmatize
        if lemmatize:
            self.lemmatizer = WordNetLemmatizer()

    def clean_text(self, text, strip_html=False, lower=True, keep_emails=False, keep_at_mentions=False):
        """Normalise *text*: handle HTML, case, e-mails, @mentions and punctuation."""
        if strip_html:
            # drop HTML tags entirely
            text = re.sub(r'<[^>]+>', '', text)
        else:
            # keep the content but neutralise the angle brackets
            text = re.sub(r'<', '(', text)
            text = re.sub(r'>', ')', text)

        if lower:
            text = text.lower()
        if not keep_emails:
            text = re.sub(r'\S+@\S+', ' ', text)
        if not keep_at_mentions:
            text = re.sub(r'\s@\S+', ' ', text)

        # Ordered substitutions; the order matters because later steps rely
        # on what earlier ones removed.
        substitutions = (
            (r'_', ' '),        # underscores become spaces
            (r'\s\'', ' '),     # single quote at the start of a word
            (r'\'\s', ' '),     # single quote at the end of a word
            (r'\.', ''),        # periods are removed outright
            (replace, ' '),     # all other punctuation (except single quotes)
            (r'\'', ''),        # remaining single quotes
            (r'\s', ' '),       # every whitespace character becomes a plain space
        )
        for pattern, replacement in substitutions:
            text = re.sub(pattern, replacement, text)

        return text.strip()

    def tokenize(self, text):
        """Return the list of unigrams of *text* that survive all filters."""
        cleaned = self.clean_text(text, self.strip_html, self.lower)

        # Pick the character-class filter once, outside the token loop.
        if not self.keep_alphanum and not self.keep_num:
            shape_filter = alpha           # letters only
        elif not self.keep_alphanum:
            shape_filter = alpha_or_num    # letters only, or numbers only
        else:
            shape_filter = None            # anything goes

        use_lemmatizer = getattr(self, "lemmatize", False)

        unigrams = []
        for token in cleaned.split():
            # '_' is the placeholder for a dropped token and never survives.
            if token == '_' or token in self.stopword_set:
                continue
            if shape_filter is not None and not shape_filter.match(token):
                continue
            if self.min_length > 0 and len(token) < self.min_length:
                continue
            if use_lemmatizer:
                token = self.lemmatizer.lemmatize(token)
            if token != '_':
                unigrams.append(token)
        return unigrams
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def make_word_embeddings(vocab):
    """Build a sparse embedding matrix for *vocab* from pretrained GloVe vectors.

    Loads the ``glove-wiki-gigaword-200`` vectors through ``gensim.downloader``.
    Row ``i`` of the result is the vector of ``vocab[i]``; words missing from
    the pretrained vocabulary keep an all-zero row.

    Args:
        vocab: Sequence of vocabulary words.

    Returns:
        A ``scipy.sparse.csr_matrix`` with ``len(vocab)`` rows and one column
        per embedding dimension.
    """
    glove_vectors = gensim.downloader.load('glove-wiki-gigaword-200')
    word_embeddings = np.zeros((len(vocab), glove_vectors.vectors.shape[1]))

    # The list of known words is exposed under a different attribute name
    # depending on the gensim version; fall back only on a missing attribute.
    try:
        key_word_list = glove_vectors.index_to_key
    except AttributeError:
        key_word_list = glove_vectors.index2word

    # A set gives O(1) membership tests; scanning the list for every
    # vocabulary word was quadratic.
    known_words = set(key_word_list)

    num_found = 0
    for i, word in enumerate(tqdm(vocab, desc="loading word embeddings")):
        if word in known_words:
            word_embeddings[i] = glove_vectors[word]
            num_found += 1

    logger.info(f'number of found embeddings: {num_found}/{len(vocab)}')

    return scipy.sparse.csr_matrix(word_embeddings)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class Preprocess:
    """Tokenize raw documents, build a vocabulary and produce bag-of-words data."""

    def __init__(self,
                 tokenizer=None,
                 test_sample_size=None,
                 test_p=0.2,
                 stopwords="English",
                 min_doc_count=0,
                 max_doc_freq=1.0,
                 keep_num=False,
                 keep_alphanum=False,
                 strip_html=False,
                 no_lower=False,
                 min_length=3,
                 min_term=0,
                 vocab_size=None,
                 seed=42,
                 verbose=True,
                 lemmatize=True,
                 ):
        """
        Args:
            tokenizer:
                Optional callable mapping a raw text to a list of tokens. When omitted,
                a ``Tokenizer`` is built from the tokenizer-related arguments below.
            test_sample_size:
                Size of the test set.
            test_p:
                Proportion of the test set. This helps sample the train set based on the size of the test set.
            stopwords:
                List of stopwords to exclude.
            min_doc_count:
                Exclude words that occur in less than this number of documents.
            max_doc_freq:
                Exclude words that occur in more than this proportion of documents.
            keep_num:
                Keep tokens made of only numbers.
            keep_alphanum:
                Keep tokens made of a mixture of letters and numbers.
            strip_html:
                Strip HTML tags.
            no_lower:
                Do not lowercase text.
            min_length:
                Minimum token length.
            min_term:
                Minimum number of tokens a document needs in order to be kept.
            vocab_size:
                Size of the vocabulary (by most common in the union of train and test sets, following above exclusions).
            seed:
                Random integer seed (only relevant for choosing test set).
            verbose:
                Set the module logger to DEBUG when true, WARNING otherwise.
            lemmatize:
                Whether to apply lemmatization to the tokens.
        """

        self.test_sample_size = test_sample_size
        self.min_doc_count = min_doc_count
        self.max_doc_freq = max_doc_freq
        self.min_term = min_term
        self.test_p = test_p
        self.vocab_size = vocab_size
        self.seed = seed

        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            self.tokenizer = Tokenizer(
                stopwords,
                keep_num,
                keep_alphanum,
                strip_html,
                no_lower,
                min_length,
                lemmatize=lemmatize
            ).tokenize

        if verbose:
            logger.set_level("DEBUG")
        else:
            logger.set_level("WARNING")

    def parse(self, texts, vocab):
        """Tokenize *texts*, keep only words in *vocab* and count them.

        Args:
            texts: A list of texts (any other value is wrapped into a one-item list).
            vocab: Vocabulary words; fixes the column order of the count matrix.

        Returns:
            ``(parsed_texts, sparse_bow)``: the space-joined in-vocabulary tokens
            of each text and the matching sparse document-term count matrix.
        """
        if not isinstance(texts, list):
            texts = [texts]

        vocab_set = set(vocab)
        parsed_texts = []
        for text in tqdm(texts, desc="parsing texts"):
            tokens = [t for t in self.tokenizer(text) if t in vocab_set]
            parsed_texts.append(" ".join(tokens))

        # The texts are already tokenized and case-normalised by self.tokenizer,
        # so the vectorizer must neither lowercase them again (that would miss
        # mixed-case vocabulary entries) nor apply its own token pattern.
        vectorizer = CountVectorizer(
            vocabulary=vocab,
            tokenizer=lambda x: x.split(),
            lowercase=False,
            token_pattern=None,
        )
        sparse_bow = vectorizer.fit_transform(parsed_texts)
        return parsed_texts, sparse_bow

    def preprocess_jsonlist(self, dataset_dir, label_name=None, use_partition=True):
        """Load a JSON-lines dataset from *dataset_dir* and preprocess it.

        With ``use_partition`` the files ``train.jsonlist`` and ``test.jsonlist``
        are read; otherwise every non-blank line of ``docs.jsonl`` is used as
        training data and the test set is empty. Each item must provide the
        keys ``'text'`` and ``'timestamp'``.

        Args:
            dataset_dir: Directory containing the input files.
            label_name: Optional item key holding the document label.
            use_partition: Whether separate train/test files exist.

        Returns:
            The dict produced by :meth:`preprocess`, extended with
            ``'train_times'``, ``'test_times'`` (when a test set exists) and
            ``'time2id'``.
        """
        if use_partition:
            train_items = file_utils.read_jsonlist(os.path.join(dataset_dir, 'train.jsonlist'))
            test_items = file_utils.read_jsonlist(os.path.join(dataset_dir, 'test.jsonlist'))
        else:
            raw_path = os.path.join(dataset_dir, 'docs.jsonl')
            with open(raw_path, 'r', encoding='utf-8') as f:
                train_items = [json.loads(line.strip()) for line in f if line.strip()]
            test_items = []

        logger.info(f"Found training documents {len(train_items)} testing documents {len(test_items)}")

        # Initialize containers
        raw_train_texts, train_labels, raw_train_times = [], [], []
        raw_test_texts, test_labels, raw_test_times = [], [], []

        # Process train items
        for item in train_items:
            raw_train_texts.append(item['text'])
            raw_train_times.append(str(item['timestamp']))
            if label_name and label_name in item:
                train_labels.append(item[label_name])

        # Process test items
        for item in test_items:
            raw_test_texts.append(item['text'])
            raw_test_times.append(str(item['timestamp']))
            if label_name and label_name in item:
                test_labels.append(item[label_name])

        # Create and apply the time2id mapping. Timestamps were converted to
        # strings above, so ids follow the lexicographic order of those strings.
        all_times = sorted(set(raw_train_times + raw_test_times))
        time2id = {t: i for i, t in enumerate(all_times)}

        train_times = np.array([time2id[t] for t in raw_train_times], dtype=np.int32)
        test_times = np.array([time2id[t] for t in raw_test_times], dtype=np.int32) if raw_test_times else None

        # Preprocess and get the indices of the documents that were kept
        rst = self.preprocess(raw_train_texts, train_labels, raw_test_texts, test_labels)
        train_idx = rst.get("train_idx")
        test_idx = rst.get("test_idx")

        # Add filtered timestamps to result for saving later
        rst["train_times"] = train_times[train_idx]
        if test_times is not None and test_idx is not None:
            rst["test_times"] = test_times[test_idx]

        # Add time2id to result dict
        rst["time2id"] = time2id

        return rst

    def convert_labels(self, train_labels, test_labels):
        """Map raw labels to integer ids shared by the train and test sets.

        Ids are assigned in sorted order over the union of both label lists.
        An argument that is ``None`` or empty is returned unchanged.

        Returns:
            ``(train_labels, test_labels)`` with labels replaced by their ids.
        """
        if not train_labels and not test_labels:
            return train_labels, test_labels

        # Either side may be None/empty; treat it as contributing no labels.
        label_list = sorted(set(train_labels or []).union(test_labels or []))
        label2id = {label: idx for idx, label in enumerate(label_list)}

        logger.info(f"label2id: {label2id}")

        if train_labels:
            train_labels = [label2id[label] for label in train_labels]

        if test_labels:
            test_labels = [label2id[label] for label in test_labels]

        return train_labels, test_labels

    def preprocess(
                self,
                raw_train_texts,
                train_labels=None,
                raw_test_texts=None,
                test_labels=None,
                pretrained_WE=True
            ):
        """Tokenize the corpora, build the vocabulary and the bag-of-words matrices.

        Args:
            raw_train_texts: Raw training documents.
            train_labels: Optional labels aligned with ``raw_train_texts``.
            raw_test_texts: Optional raw test documents.
            test_labels: Optional labels aligned with ``raw_test_texts``.
            pretrained_WE: Also build pretrained word embeddings for the vocabulary.

        Returns:
            A dict with ``'vocab'``, ``'train_bow'``, ``'train_texts'`` and
            ``'train_idx'`` (indices of the kept training documents), plus
            ``'train_labels'``, ``'test_texts'``, ``'test_bow'``, ``'test_idx'``,
            ``'test_labels'`` and ``'word_embeddings'`` when applicable.

        Raises:
            ValueError: If tokenization leaves no token at all, so that no
                vocabulary can be built.
        """
        np.random.seed(self.seed)

        train_texts = []
        test_texts = []
        doc_counts_counter = Counter()

        train_labels, test_labels = self.convert_labels(train_labels, test_labels)

        for text in tqdm(raw_train_texts, desc="loading train texts"):
            tokens = self.tokenizer(text)
            doc_counts_counter.update(set(tokens))
            train_texts.append(' '.join(tokens))

        if raw_test_texts:
            for text in tqdm(raw_test_texts, desc="loading test texts"):
                tokens = self.tokenizer(text)
                doc_counts_counter.update(set(tokens))
                test_texts.append(' '.join(tokens))

        if not doc_counts_counter:
            raise ValueError("No tokens left after tokenization; cannot build a vocabulary.")

        # Words ordered from most to least frequent (by document count).
        words, doc_counts = zip(*doc_counts_counter.most_common())
        doc_freqs = np.array(doc_counts) / float(len(train_texts) + len(test_texts))

        vocab = [
            word for i, word in enumerate(words)
            if doc_counts[i] >= self.min_doc_count and doc_freqs[i] <= self.max_doc_freq
        ]

        # keep only the most common words when a vocabulary size is requested
        if self.vocab_size is not None:
            vocab = vocab[:self.vocab_size]

        vocab.sort()

        # An explicit integer dtype keeps the arrays usable as indices even
        # when no document survives (an empty array would default to float).
        train_idx = np.asarray(
            [i for i, text in enumerate(train_texts) if len(text.split()) >= self.min_term],
            dtype=np.int64,
        )

        if raw_test_texts is not None:
            test_idx = np.asarray(
                [i for i, text in enumerate(test_texts) if len(text.split()) >= self.min_term],
                dtype=np.int64,
            )
        else:
            test_idx = None

        # randomly sample; only meaningful when there are test documents,
        # otherwise the train set would be scaled down to zero documents
        if self.test_sample_size and raw_test_texts:
            logger.info("sample train and test sets...")

            train_num = len(train_idx)
            test_num = len(test_idx)
            test_sample_size = min(test_num, self.test_sample_size)
            train_sample_size = int((test_sample_size / self.test_p) * (1 - self.test_p))
            if train_sample_size > train_num:
                test_sample_size = int((train_num / (1 - self.test_p)) * self.test_p)
                train_sample_size = train_num

            train_idx = train_idx[np.sort(np.random.choice(train_num, train_sample_size, replace=False))]
            test_idx = test_idx[np.sort(np.random.choice(test_num, test_sample_size, replace=False))]

            logger.info(f"sampled train size: {len(train_idx)}")
            logger.info(f"sampled test size: {len(test_idx)}")

        train_texts, train_bow = self.parse([train_texts[i] for i in train_idx], vocab)

        rst = {
            'vocab': vocab,
            'train_bow': train_bow,
            "train_texts": train_texts,
            "train_idx": train_idx,  # indices of kept train samples
        }

        if train_labels:
            rst['train_labels'] = np.asarray(train_labels)[train_idx]

        logger.info(f"Real vocab size: {len(vocab)}")
        logger.info(f"Real training size: {len(train_texts)} \t avg length: {rst['train_bow'].sum() / len(train_texts):.3f}")

        if raw_test_texts:
            rst['test_texts'], rst['test_bow'] = self.parse([test_texts[i] for i in test_idx], vocab)
            rst["test_idx"] = test_idx  # indices of kept test samples

            if test_labels:
                rst['test_labels'] = np.asarray(test_labels)[test_idx]

            logger.info(f"Real testing size: {len(rst['test_texts'])} \t avg length: {rst['test_bow'].sum() / len(rst['test_texts']):.3f}")

        if pretrained_WE:
            rst['word_embeddings'] = make_word_embeddings(vocab)

        return rst

    def save(
                self,
                output_dir,
                vocab,
                train_texts,
                train_bow,
                word_embeddings=None,
                train_labels=None,
                test_texts=None,
                test_bow=None,
                test_labels=None,
                train_times=None,
                test_times=None,
                time2id=None
            ):
        """Write the preprocessed dataset into *output_dir*.

        ``vocab.txt``, ``train_texts.txt`` and ``train_bow.npz`` are always
        written; every optional argument is written to its own file only when
        it is provided. ``time2id`` is stored as JSON in ``time2id.txt``.
        """
        file_utils.make_dir(output_dir)

        file_utils.save_text(vocab, f"{output_dir}/vocab.txt")
        file_utils.save_text(train_texts, f"{output_dir}/train_texts.txt")
        scipy.sparse.save_npz(f"{output_dir}/train_bow.npz", scipy.sparse.csr_matrix(train_bow))

        if word_embeddings is not None:
            scipy.sparse.save_npz(f"{output_dir}/word_embeddings.npz", word_embeddings)

        # Labels may arrive as NumPy arrays, whose truth value is ambiguous,
        # so test for presence and length explicitly.
        if train_labels is not None and len(train_labels) > 0:
            np.savetxt(f"{output_dir}/train_labels.txt", train_labels, fmt='%i')

        if train_times is not None:
            np.savetxt(f"{output_dir}/train_times.txt", train_times, fmt='%i')

        if test_bow is not None:
            scipy.sparse.save_npz(f"{output_dir}/test_bow.npz", scipy.sparse.csr_matrix(test_bow))

        if test_texts is not None:
            file_utils.save_text(test_texts, f"{output_dir}/test_texts.txt")

        if test_labels is not None and len(test_labels) > 0:
            np.savetxt(f"{output_dir}/test_labels.txt", test_labels, fmt='%i')

        if test_times is not None:
            np.savetxt(f"{output_dir}/test_times.txt", test_times, fmt='%i')

        # Save time2id mapping if provided
        if time2id is not None:
            with open(f"{output_dir}/time2id.txt", "w", encoding="utf-8") as f:
                json.dump(time2id, f, indent=2)
|
| 447 |
+
|
backend/datasets/data/download.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import zipfile
|
| 3 |
+
from torchvision.datasets.utils import download_url
|
| 4 |
+
from backend.datasets.utils.logger import Logger
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
logger = Logger("WARNING")


def download_dataset(dataset_name, cache_path="~/.topmost"):
    """Download the zipped dataset *dataset_name* and unpack it into *cache_path*.

    The archive is fetched from the TopMost GitHub repository, extracted into
    the (user-expanded) cache directory and then deleted.

    Args:
        dataset_name: Name of the dataset; ``<dataset_name>.zip`` is downloaded.
        cache_path: Directory receiving the extracted files.
    """
    cache_root = os.path.expanduser(cache_path)
    raw_filename = f'{dataset_name}.zip'

    # Wikitext-103 is downloaded from Git LFS; everything else from the raw host.
    served_from_lfs = dataset_name in ['Wikitext-103']
    zipped_dataset_url = (
        f"https://media.githubusercontent.com/media/BobXWu/TopMost/main/data/{raw_filename}"
        if served_from_lfs
        else f"https://raw.githubusercontent.com/BobXWu/TopMost/master/data/{raw_filename}"
    )

    logger.info(zipped_dataset_url)

    download_url(zipped_dataset_url, root=cache_root, filename=raw_filename, md5=None)

    archive_path = f'{cache_root}/{raw_filename}'
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(cache_root)

    # the archive is no longer needed once extracted
    os.remove(archive_path)


if __name__ == '__main__':
    download_dataset('20NG')
|
backend/datasets/data/file_utils.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def make_dir(path):
    """Create directory *path*, including missing parents; existing directories are fine."""
    os.makedirs(name=path, exist_ok=True)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def read_text(path):
    """Return the lines of the UTF-8 file *path*, each stripped of surrounding whitespace.

    Undecodable bytes are ignored rather than raising.
    """
    with open(path, 'r', encoding='utf-8', errors='ignore') as handle:
        return [row.strip() for row in handle]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def save_text(texts, path):
    """Write each text of *texts* on its own line of the UTF-8 file *path*.

    Every text is stripped of surrounding whitespace before being written.
    """
    with open(path, 'w', encoding='utf-8') as handle:
        for entry in texts:
            handle.write(f"{entry.strip()}\n")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def read_jsonlist(path):
    """Read a JSON-lines file and return the decoded objects as a list.

    Args:
        path: Path of a UTF-8 file holding one JSON document per line.

    Returns:
        The decoded objects, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, 'r', encoding='utf-8') as input_file:
        # Blank or whitespace-only lines (e.g. a trailing newline) carry no
        # record and would make json.loads fail, so they are skipped.
        return [json.loads(line) for line in input_file if line.strip()]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def save_jsonlist(list_of_json_objects, path, sort_keys=True):
    """Write *list_of_json_objects* to *path* as JSON lines (one object per line).

    Args:
        list_of_json_objects: Iterable of JSON-serialisable objects.
        path: Destination file, written as UTF-8.
        sort_keys: Forwarded to ``json.dumps``.
    """
    with open(path, 'w', encoding='utf-8') as sink:
        for record in list_of_json_objects:
            encoded = json.dumps(record, sort_keys=sort_keys)
            sink.write(encoded + '\n')
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def split_text_word(texts):
    """Split every text of *texts* on whitespace and return the list of word lists."""
    return [document.split() for document in texts]
|
backend/datasets/dynamic_dataset.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset, DataLoader
|
| 3 |
+
import numpy as np
|
| 4 |
+
import scipy.sparse
|
| 5 |
+
import scipy.io
|
| 6 |
+
from backend.datasets.data import file_utils
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class _SequentialDataset(Dataset):
    """Dataset yielding, per document, its bag-of-words, time index and time-slice word frequencies."""

    def __init__(self, bow, times, time_wordfreq):
        """Store the per-document ``bow`` and ``times`` and the per-time ``time_wordfreq``."""
        super().__init__()
        self.bow, self.times, self.time_wordfreq = bow, times, time_wordfreq

    def __len__(self):
        """Number of documents."""
        return len(self.bow)

    def __getitem__(self, index):
        """Return a dict with the ``'bow'``, ``'times'`` and ``'time_wordfreq'`` entries of document *index*."""
        time_slice = self.times[index]
        return {
            'bow': self.bow[index],
            'times': time_slice,
            # word frequencies are looked up by the document's time index
            'time_wordfreq': self.time_wordfreq[time_slice],
        }
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class DynamicDataset:
    """Preprocessed time-sliced corpus loaded from disk, optionally as torch tensors."""

    def __init__(self, dataset_dir, batch_size=200, read_labels=False, use_partition=False, device='cuda', as_tensor=True):
        """
        Args:
            dataset_dir: Directory holding the files read by :meth:`load_data`.
            batch_size: Batch size of the training ``DataLoader``.
            read_labels: Also load the label files.
            use_partition: Also load (and wrap) the test split.
            device: Torch device the tensors are moved to.
            as_tensor: Convert the arrays to tensors and build the datasets
                and the training dataloader.
        """
        self.load_data(dataset_dir, read_labels, use_partition)

        self.vocab_size = len(self.vocab)
        self.train_size = len(self.train_bow)
        # time indices start at 0, so the largest index + 1 is the slice count
        self.num_times = int(self.train_times.max()) + 1
        self.train_time_wordfreq = self.get_time_wordfreq(self.train_bow, self.train_times)

        print('train size: ', len(self.train_bow))
        if use_partition:
            print('test size: ', len(self.test_bow))
        print('vocab size: ', len(self.vocab))
        print('average length: {:.3f}'.format(self.train_bow.sum(1).mean().item()))
        print('num of each time slice: ', self.num_times, np.bincount(self.train_times))

        if as_tensor:
            self.train_bow = torch.from_numpy(self.train_bow).float().to(device)
            self.train_times = torch.from_numpy(self.train_times).long().to(device)
            self.train_time_wordfreq = torch.from_numpy(self.train_time_wordfreq).float().to(device)

            if use_partition:
                self.test_bow = torch.from_numpy(self.test_bow).float().to(device)
                self.test_times = torch.from_numpy(self.test_times).long().to(device)

            self.train_dataset = _SequentialDataset(self.train_bow, self.train_times, self.train_time_wordfreq)

            if use_partition:
                # the test split reuses the word frequencies of the training split
                self.test_dataset = _SequentialDataset(self.test_bow, self.test_times, self.train_time_wordfreq)

            self.train_dataloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)

    def load_data(self, path, read_labels, use_partition=False):
        """Load the dataset files found under *path* into attributes of ``self``.

        Always reads ``train_bow.npz``, ``train_texts.txt``, ``train_times.txt``,
        ``vocab.txt`` and ``word_embeddings.npz``; the label files and the test
        split files are read only when requested.
        """
        self.train_bow = scipy.sparse.load_npz(f'{path}/train_bow.npz').toarray().astype('float32')
        self.train_texts = file_utils.read_text(f'{path}/train_texts.txt')
        self.train_times = np.loadtxt(f'{path}/train_times.txt').astype('int32')
        self.vocab = file_utils.read_text(f'{path}/vocab.txt')
        self.word_embeddings = scipy.sparse.load_npz(f'{path}/word_embeddings.npz').toarray().astype('float32')

        self.pretrained_WE = self.word_embeddings  # preserve compatibility

        if read_labels:
            self.train_labels = np.loadtxt(f'{path}/train_labels.txt').astype('int32')

        if use_partition:
            self.test_bow = scipy.sparse.load_npz(f'{path}/test_bow.npz').toarray().astype('float32')
            self.test_texts = file_utils.read_text(f'{path}/test_texts.txt')
            self.test_times = np.loadtxt(f'{path}/test_times.txt').astype('int32')
            if read_labels:
                self.test_labels = np.loadtxt(f'{path}/test_labels.txt').astype('int32')

    # word frequency at each time slice.
    def get_time_wordfreq(self, bow, times):
        """Return the average word counts per document for every time slice.

        Args:
            bow: Array of per-document word counts, one row per document.
            times: Integer time index of each document.

        Returns:
            Array of shape ``(self.num_times, self.vocab_size)``. A time slice
            without any document gets an all-zero row.
        """
        train_time_wordfreq = np.zeros((self.num_times, self.vocab_size))
        for time in range(self.num_times):
            idx = np.where(times == time)[0]
            train_time_wordfreq[time] += bow[idx].sum(0)

        # minlength keeps one count per slice even when trailing slices are
        # empty; clamping to 1 avoids 0/0 = NaN rows for empty slices.
        cnt_times = np.maximum(np.bincount(times, minlength=self.num_times), 1)

        train_time_wordfreq = train_time_wordfreq / cnt_times[:, np.newaxis]
        return train_time_wordfreq
|
backend/datasets/preprocess.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
import tempfile
|
| 6 |
+
import gensim.downloader
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from backend.datasets.utils.logger import Logger
|
| 9 |
+
import scipy.sparse
|
| 10 |
+
from gensim.models.phrases import Phrases, Phraser
|
| 11 |
+
from typing import List, Union
|
| 12 |
+
from octis.preprocessing.preprocessing import Preprocessing
|
| 13 |
+
|
| 14 |
+
logger = Logger("WARNING")
|
| 15 |
+
|
| 16 |
+
class Preprocessor:
    """Convert a JSONL corpus into vocabulary, bag-of-words, embedding and time files.

    The pipeline implemented by :meth:`preprocess` is:

    1. read ``text`` / ``timestamp`` / ``label`` fields from a JSONL file,
    2. run OCTIS ``Preprocessing.preprocess_dataset`` on the texts,
    3. optionally merge frequent bigrams with Gensim ``Phrases``,
    4. build a sorted vocabulary and a sparse document-term count matrix,
    5. look up GloVe vectors for every vocabulary entry,
    6. map timestamps to consecutive integer ids,
    7. write all artefacts into ``output_folder``.
    """

    def __init__(self,
                 docs_jsonl_path: str,
                 output_folder: str,
                 use_partition: bool = False,
                 use_bigrams: bool = False,
                 min_count_bigram: int = 5,
                 threshold_bigram: int = 10,
                 remove_punctuation: bool = True,
                 lemmatize: bool = True,
                 stopword_list: Union[str, List[str]] = None,
                 min_chars: int = 3,
                 min_words_docs: int = 10,
                 min_df: Union[int, float] = 0.0,
                 max_df: Union[int, float] = 1.0,
                 max_features: int = None,
                 language: str = 'english'):
        """Store the configuration and build the OCTIS preprocessor.

        Args:
            docs_jsonl_path: Path of the input file, one JSON object per line.
            output_folder: Directory that receives every output file; it is
                created if it does not exist.
            use_partition: When true, the first 80% of the surviving documents
                are written as the train split and the rest as the test split.
            use_bigrams: When true, Gensim ``Phrases`` is applied after OCTIS.
            min_count_bigram: ``min_count`` passed to ``Phrases``.
            threshold_bigram: ``threshold`` passed to ``Phrases``.
            remove_punctuation, lemmatize, stopword_list, min_chars,
            min_words_docs, min_df, max_df, max_features, language:
                Forwarded unchanged as keyword arguments to OCTIS
                ``Preprocessing``.
        """
        self.docs_jsonl_path = docs_jsonl_path
        self.output_folder = output_folder
        self.use_partition = use_partition
        self.use_bigrams = use_bigrams
        self.min_count_bigram = min_count_bigram
        self.threshold_bigram = threshold_bigram

        os.makedirs(self.output_folder, exist_ok=True)

        self.preprocessing_params = {
            'remove_punctuation': remove_punctuation,
            'lemmatize': lemmatize,
            'stopword_list': stopword_list,
            'min_chars': min_chars,
            'min_words_docs': min_words_docs,
            'min_df': min_df,
            'max_df': max_df,
            'max_features': max_features,
            'language': language
        }
        self.preprocessor_octis = Preprocessing(**self.preprocessing_params)

    @staticmethod
    def _write_lines(path, items):
        """Write every item of ``items`` to ``path`` as one UTF-8 line each."""
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(f"{item}\n" for item in items)

    @staticmethod
    def _remove_temp_dir(temp_dir):
        """Delete the files directly inside ``temp_dir`` and then the directory itself."""
        for name in os.listdir(temp_dir):
            os.remove(os.path.join(temp_dir, name))
        os.rmdir(temp_dir)

    @staticmethod
    def _read_original_indexes(octis_dataset):
        """Return the original document indexes kept by OCTIS as a list of ints.

        The indexes are obtained by asking the dataset to save them to a
        temporary file (via its ``_save_document_indexes`` method) and reading
        that file back. The temporary directory is removed even if saving or
        parsing fails.
        """
        with tempfile.TemporaryDirectory() as temp_indexes_dir:
            temp_indexes_file = os.path.join(temp_indexes_dir, "temp_original_indexes.txt")

            print(f"Saving __original_indexes to {temp_indexes_file}...")
            octis_dataset._save_document_indexes(temp_indexes_file)

            with open(temp_indexes_file, 'r') as f_indexes:
                # Blank lines carry no index; skipping them avoids int('') errors.
                indexes = [int(line.strip()) for line in f_indexes if line.strip()]

        print("Temporary indexes file cleaned up.")
        return indexes

    @staticmethod
    def _build_bow_matrix(corpus, word_to_id):
        """Build a (num_documents, vocab_size) sparse count matrix in CSC format.

        Args:
            corpus: List of tokenised documents (list of list of str).
            word_to_id: Mapping from every token in ``corpus`` to its column.
        """
        rows, cols, data = [], [], []
        for doc_id, doc_tokens in enumerate(corpus):
            doc_word_counts = {}
            for token in doc_tokens:
                col_id = word_to_id[token]
                doc_word_counts[col_id] = doc_word_counts.get(col_id, 0) + 1
            for col_id, count in doc_word_counts.items():
                rows.append(doc_id)
                cols.append(col_id)
                data.append(count)

        return scipy.sparse.csc_matrix(
            (data, (rows, cols)), shape=(len(corpus), len(word_to_id))
        )

    def _load_data_to_temp_files(self):
        """Load the JSONL corpus and write the temporary input files for OCTIS.

        Each non-blank line of ``self.docs_jsonl_path`` is parsed as JSON. The
        ``text`` field has all whitespace runs (including newlines) collapsed
        to single spaces so that every document occupies exactly one line in
        the temporary documents file. A labels file is written only if at
        least one document has a non-empty ``label``.

        Returns:
            Tuple ``(raw_texts, raw_timestamps, raw_labels, temp_docs_path,
            temp_labels_path, temp_dir)``. ``temp_labels_path`` is ``None``
            when no labels were found. The caller owns ``temp_dir`` and is
            responsible for deleting it.
        """
        raw_texts = []
        raw_timestamps = []
        raw_labels = []
        has_labels = False

        with open(self.docs_jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines are not JSON documents; skipping them keeps
                    # texts, timestamps and labels aligned with each other.
                    continue
                data = json.loads(line)
                raw_texts.append(" ".join(data.get('text', '').split()))
                raw_timestamps.append(data.get('timestamp', ''))
                label = data.get('label', '')
                if label:
                    has_labels = True
                # One entry per document (possibly empty) so that the labels
                # file stays line-aligned with the documents file.
                raw_labels.append(label)

        temp_dir = tempfile.mkdtemp()
        temp_docs_path = os.path.join(temp_dir, "temp_docs.txt")
        temp_labels_path = None

        self._write_lines(temp_docs_path, raw_texts)

        if has_labels:
            temp_labels_path = os.path.join(temp_dir, "temp_labels.txt")
            self._write_lines(temp_labels_path, raw_labels)

        print(f"Loaded {len(raw_texts)} raw documents and created temporary files in {temp_dir}.")
        return raw_texts, raw_timestamps, raw_labels, temp_docs_path, temp_labels_path, temp_dir

    def _make_word_embeddings(self, vocab):
        """Return a dense float32 embedding matrix for ``vocab`` using GloVe.

        The ``glove-wiki-gigaword-200`` vectors are loaded through
        ``gensim.downloader``. A vocabulary entry containing ``_`` is treated
        as an n-gram: its embedding is the sum of the embeddings of its
        ``_``-separated parts, but only if every part is present in GloVe.
        Entries that cannot be resolved keep an all-zero row.

        Args:
            vocab: Sequence of vocabulary strings; row ``i`` of the result
                corresponds to ``vocab[i]``.

        Returns:
            ``numpy.ndarray`` of shape ``(len(vocab), embedding_dim)``.
        """
        print("Loading GloVe word embeddings...")
        glove_vectors = gensim.downloader.load('glove-wiki-gigaword-200')

        embedding_dim = glove_vectors.vectors.shape[1]
        # Rows default to zero so unresolved words need no special handling.
        word_embeddings = np.zeros((len(vocab), embedding_dim), dtype=np.float32)

        try:
            # A set gives O(1) average membership tests.
            key_word_list = set(glove_vectors.index_to_key)
        except AttributeError:  # For older gensim versions
            key_word_list = set(glove_vectors.index2word)

        num_found = 0
        print("Generating word embeddings for vocabulary (including n-grams)...")
        for i, word in enumerate(tqdm(vocab, desc="Processing vocabulary words")):
            if '_' in word:
                parts = word.split('_')
                if all(part in key_word_list for part in parts):
                    resultant_vector = np.zeros(embedding_dim, dtype=np.float32)
                    for part in parts:
                        resultant_vector += glove_vectors[part]
                    word_embeddings[i] = resultant_vector
                    num_found += 1
            elif word in key_word_list:
                word_embeddings[i] = glove_vectors[word]
                num_found += 1

        logger.info(f'Number of found embeddings (including n-grams): {num_found}/{len(vocab)}')
        return word_embeddings

    def _save_doc_length_stats(self, filepath: str, output_path: str):
        """Write length statistics of the documents in ``filepath`` as JSON.

        Every non-empty stripped line counts as one document, and its length
        is ``len()`` of that stripped line (i.e. a character count). The JSON
        object holds ``avg_len``, ``std_len``, ``max_len``, ``min_len`` and
        ``num_docs``. If the file cannot be read or contains no documents, a
        message is printed and nothing is written.
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                doc_lengths = [len(doc) for doc in (line.strip() for line in f) if doc]
        except (OSError, ValueError) as e:
            # Best effort: the statistics are auxiliary output, so a read or
            # decode failure is reported rather than propagated.
            print(f"Error processing '{filepath}': {e}")
            return

        if not doc_lengths:
            print(f"No documents found in '{filepath}'.")
            return

        stats = {
            "avg_len": float(np.mean(doc_lengths)),
            "std_len": float(np.std(doc_lengths)),
            "max_len": int(np.max(doc_lengths)),
            "min_len": int(np.min(doc_lengths)),
            "num_docs": int(len(doc_lengths))
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=4)
        print(f"Saved document length stats to: {output_path}")

    def preprocess(self):
        """Run the whole pipeline and write the results to ``self.output_folder``.

        Files written: ``vocab.txt``, ``train_texts.txt``,
        ``length_stats.json``, ``train_bow.npz``, ``word_embeddings.npz``,
        ``train_times.txt`` and ``time2id.txt``; ``train_labels.txt`` when
        labels exist; and, when ``use_partition`` is true, ``test_bow.npz``,
        ``test_texts.txt``, ``test_times.txt`` and (with labels)
        ``test_labels.txt``.

        Raises:
            ValueError: If no document survives OCTIS preprocessing.
        """
        print("Loading data and creating temporary files for OCTIS...")
        _, raw_timestamps, _, temp_docs_path, temp_labels_path, temp_dir = \
            self._load_data_to_temp_files()

        try:
            print("Starting OCTIS pre-processing using file paths and specified parameters...")
            octis_dataset = self.preprocessor_octis.preprocess_dataset(
                documents_path=temp_docs_path,
                labels_path=temp_labels_path
            )
        finally:
            # Remove the temporary inputs even when OCTIS raises.
            self._remove_temp_dir(temp_dir)
            print(f"Temporary files in {temp_dir} cleaned up.")

        original_indexes_after_octis = self._read_original_indexes(octis_dataset)

        processed_corpus_octis_list = octis_dataset.get_corpus()  # List of list of tokens
        processed_labels_octis = octis_dataset.get_labels()  # List of labels

        if not original_indexes_after_octis:
            raise ValueError("No documents survived OCTIS preprocessing; nothing to save.")

        print("Max index in original_indexes_after_octis:", max(original_indexes_after_octis))
        print("Length of raw_timestamps:", len(raw_timestamps))

        # Keep only the timestamps of documents that survived OCTIS.
        filtered_timestamps = [raw_timestamps[i] for i in original_indexes_after_octis]

        print(f"OCTIS preprocessing complete. {len(processed_corpus_octis_list)} documents remaining.")

        if self.use_bigrams:
            print("Generating bigrams with Gensim...")
            phrases = Phrases(
                processed_corpus_octis_list,
                min_count=self.min_count_bigram,
                threshold=self.threshold_bigram,
            )
            bigram_phraser = Phraser(phrases)
            bigrammed_corpus_list = [bigram_phraser[doc] for doc in processed_corpus_octis_list]
            print("Bigram generation complete.")
        else:
            print("Skipping bigram generation as 'use_bigrams' is False.")
            bigrammed_corpus_list = processed_corpus_octis_list

        # Space-joined form for the text files; the token lists feed the BOW.
        bigrammed_texts_for_file = [" ".join(doc) for doc in bigrammed_corpus_list]

        # The vocabulary is the sorted set of all tokens left after bigramming.
        vocab = sorted({token for doc in bigrammed_corpus_list for token in doc})
        word_to_id = {word: i for i, word in enumerate(vocab)}

        print("Creating Bag-of-Words representations...")
        bow_matrix = self._build_bow_matrix(bigrammed_corpus_list, word_to_id)
        print("Bag-of-Words complete.")

        if self.use_partition:
            # Positional 80/20 split: no shuffling is applied here.
            train_size = int(0.8 * len(bigrammed_corpus_list))

            train_texts = bigrammed_texts_for_file[:train_size]
            train_bow_matrix = bow_matrix[:train_size]
            train_timestamps = filtered_timestamps[:train_size]
            train_labels = processed_labels_octis[:train_size] if processed_labels_octis else []

            test_texts = bigrammed_texts_for_file[train_size:]
            test_bow_matrix = bow_matrix[train_size:]
            test_timestamps = filtered_timestamps[train_size:]
            test_labels = processed_labels_octis[train_size:] if processed_labels_octis else []
        else:
            train_texts = bigrammed_texts_for_file
            train_bow_matrix = bow_matrix
            train_timestamps = filtered_timestamps
            train_labels = processed_labels_octis
            test_texts = []
            test_timestamps = []
            test_labels = []

        word_embeddings = self._make_word_embeddings(vocab)

        # Map each distinct timestamp to 0, 1, 2, ... in sorted order.
        print("Processing timestamps...")
        unique_timestamps = sorted(set(train_timestamps + test_timestamps))
        time_to_id = {timestamp: i for i, timestamp in enumerate(unique_timestamps)}

        train_times_ids = [time_to_id[ts] for ts in train_timestamps]
        test_times_ids = [time_to_id[ts] for ts in test_timestamps] if self.use_partition else []
        print("Timestamps processed.")

        print(f"Saving preprocessed files to {self.output_folder}...")
        out = self.output_folder

        # 1. vocab.txt
        self._write_lines(os.path.join(out, "vocab.txt"), vocab)

        # 2. train_texts.txt
        train_text_path = os.path.join(out, "train_texts.txt")
        self._write_lines(train_text_path, train_texts)

        # Document length stats are computed from the file just written.
        self._save_doc_length_stats(train_text_path, os.path.join(out, "length_stats.json"))

        # 3. train_bow.npz
        scipy.sparse.save_npz(os.path.join(out, "train_bow.npz"), train_bow_matrix)

        # 4. word_embeddings.npz (stored in sparse CSR form)
        sparse_word_embeddings = scipy.sparse.csr_matrix(word_embeddings)
        scipy.sparse.save_npz(os.path.join(out, "word_embeddings.npz"), sparse_word_embeddings)

        # 5. train_labels.txt (if labels exist)
        if train_labels:
            self._write_lines(os.path.join(out, "train_labels.txt"), train_labels)

        # 6. train_times.txt
        self._write_lines(os.path.join(out, "train_times.txt"), train_times_ids)

        if self.use_partition:
            # 7. test_bow.npz
            scipy.sparse.save_npz(os.path.join(out, "test_bow.npz"), test_bow_matrix)

            # 8. test_texts.txt
            self._write_lines(os.path.join(out, "test_texts.txt"), test_texts)

            # 9. test_labels.txt (if labels exist)
            if test_labels:
                self._write_lines(os.path.join(out, "test_labels.txt"), test_labels)

            # 10. test_times.txt
            self._write_lines(os.path.join(out, "test_times.txt"), test_times_ids)

        # 11. time2id.txt -- the dict is already in ascending id order because
        # it was built by enumerating the sorted timestamps.
        with open(os.path.join(out, "time2id.txt"), "w", encoding="utf-8") as f:
            json.dump(time_to_id, f, indent=4)

        print("All files saved successfully.")
|
backend/datasets/utils/_utils.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from backend.datasets.data import file_utils
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_top_words(beta, vocab, num_top_words, verbose=False):
    """Return the highest-weighted words of every topic as space-joined strings.

    Args:
        beta: Iterable of per-topic weight vectors; each vector has one
            weight per entry of ``vocab``.
        vocab: Sequence of words; position ``j`` corresponds to weight ``j``.
        num_top_words: Number of words to keep per topic. If it exceeds the
            vocabulary size, all words are returned.
        verbose: When true, each topic string is also printed.

    Returns:
        List with one string per topic, words ordered from highest to lowest
        weight.
    """
    # Convert once: the vocabulary does not change between topics.
    vocab_array = np.array(vocab)

    topic_str_list = []
    for i, topic_dist in enumerate(beta):
        # argsort is ascending, so walk backwards from the end to take the top words.
        topic_words = vocab_array[np.argsort(topic_dist)][:-(num_top_words + 1):-1]
        topic_str = ' '.join(topic_words)
        topic_str_list.append(topic_str)
        if verbose:
            print('Topic {}: {}'.format(i, topic_str))

    return topic_str_list
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_stopwords_set(stopwords=None):
    """Return a frozenset of stopwords.

    Args:
        stopwords: One of

            * ``None`` (default) -- yields an empty set;
            * ``'English'`` -- Gensim's built-in ``STOPWORDS``;
            * ``'mallet'`` or ``'snowball'`` -- the ``stopwords`` dataset is
              downloaded into the current directory and
              ``./stopwords/<name>_stopwords.txt`` is read;
            * any other iterable -- its elements become the set. Note that an
              unrecognised *string* is also an iterable, so it is split into
              its individual characters.

    Returns:
        ``frozenset`` of the selected stopwords.
    """
    if stopwords is None:
        # Avoid a shared mutable default; no stopwords means an empty set.
        stopwords = ()

    if stopwords == 'English':
        from gensim.parsing.preprocessing import STOPWORDS
        stopwords = STOPWORDS

    elif stopwords in ['mallet', 'snowball']:
        # Imported lazily: only the named lists need the download helper.
        from backend.datasets.data.download import download_dataset

        download_dataset('stopwords', cache_path='./')
        path = f'./stopwords/{stopwords}_stopwords.txt'
        stopwords = file_utils.read_text(path)

    return frozenset(stopwords)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if __name__ == '__main__':
    # Manual smoke test: show ten entries from each named stopword source,
    # then from the default (no argument) call.
    for source in ('English', 'mallet', 'snowball'):
        print(list(get_stopwords_set(source))[:10])
    print(list(get_stopwords_set())[:10])
|
backend/datasets/utils/logger.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Logger:
    """Thin wrapper around the shared ``'TopMost'`` :mod:`logging` logger.

    Every instance wraps the same underlying logger object, so the level set
    by the most recently created instance applies to all of them. The logger
    has exactly one handler after construction and does not propagate to the
    root logger.
    """

    def __init__(self, level):
        """Configure the shared logger.

        Args:
            level: One of ``"DEBUG"``, ``"INFO"``, ``"WARNING"``, ``"ERROR"``
                or ``"CRITICAL"``. Any other value leaves the current level
                unchanged.
        """
        self.logger = logging.getLogger('TopMost')
        self.set_level(level)
        self._add_handler()
        self.logger.propagate = False

    def info(self, message):
        """Log ``message`` at INFO level."""
        self.logger.info(f"{message}")

    def warning(self, message):
        """Log ``message`` at WARNING level, prefixed with ``"WARNING: "``."""
        self.logger.warning(f"WARNING: {message}")

    def set_level(self, level):
        """Set the logger level if ``level`` is a recognised level name."""
        levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if level in levels:
            self.logger.setLevel(level)

    def _add_handler(self):
        """Ensure the shared logger has exactly one handler.

        If handlers already exist only the first is kept; otherwise a stream
        handler with the ``asctime - name - message`` format is attached. This
        avoids building a throwaway handler each time an instance is created.
        """
        if self.logger.handlers:
            self.logger.handlers = self.logger.handlers[:1]
            return

        sh = logging.StreamHandler()
        sh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(message)s'))
        self.logger.addHandler(sh)
|
backend/evaluation/CoherenceModel_ttc.py
ADDED
|
@@ -0,0 +1,862 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import multiprocessing as mp
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from gensim import interfaces, matutils
|
| 8 |
+
from gensim import utils
|
| 9 |
+
from gensim.topic_coherence import (
|
| 10 |
+
segmentation, probability_estimation,
|
| 11 |
+
direct_confirmation_measure, indirect_confirmation_measure,
|
| 12 |
+
aggregation,
|
| 13 |
+
)
|
| 14 |
+
from gensim.topic_coherence.probability_estimation import unique_ids_from_segments
|
| 15 |
+
|
| 16 |
+
# Module-level logger.
logger = logging.getLogger(__name__)

# Coherence measures grouped by the kind of probability estimation they use.
BOOLEAN_DOCUMENT_BASED = {'u_mass'}
SLIDING_WINDOW_BASED = {'c_v', 'c_uci', 'c_npmi', 'c_w2v'}

# A coherence pipeline is a 4-tuple of callables: segmentation (seg),
# probability estimation (prob), confirmation measure (conf) and
# aggregation (aggr).
_make_pipeline = namedtuple('Coherence_Measure', 'seg, prob, conf, aggr')

# Supported coherence measures and the pipeline stages each one uses.
COHERENCE_MEASURES = {
    'u_mass': _make_pipeline(
        seg=segmentation.s_one_pre,
        prob=probability_estimation.p_boolean_document,
        conf=direct_confirmation_measure.log_conditional_probability,
        aggr=aggregation.arithmetic_mean,
    ),
    'c_v': _make_pipeline(
        seg=segmentation.s_one_set,
        prob=probability_estimation.p_boolean_sliding_window,
        conf=indirect_confirmation_measure.cosine_similarity,
        aggr=aggregation.arithmetic_mean,
    ),
    'c_w2v': _make_pipeline(
        seg=segmentation.s_one_set,
        prob=probability_estimation.p_word2vec,
        conf=indirect_confirmation_measure.word2vec_similarity,
        aggr=aggregation.arithmetic_mean,
    ),
    'c_uci': _make_pipeline(
        seg=segmentation.s_one_one,
        prob=probability_estimation.p_boolean_sliding_window,
        conf=direct_confirmation_measure.log_ratio_measure,
        aggr=aggregation.arithmetic_mean,
    ),
    'c_npmi': _make_pipeline(
        seg=segmentation.s_one_one,
        prob=probability_estimation.p_boolean_sliding_window,
        conf=direct_confirmation_measure.log_ratio_measure,
        aggr=aggregation.arithmetic_mean,
    ),
}

# Default sliding-window size per measure; None for u_mass, which does not
# use a sliding window.
SLIDING_WINDOW_SIZES = {
    'c_v': 110,
    'c_w2v': 5,
    'c_uci': 10,
    'c_npmi': 10,
    'u_mass': None,
}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class CoherenceModel_ttc(interfaces.TransformationABC):
|
| 73 |
+
"""Objects of this class allow for building and maintaining a model for topic coherence.
|
| 74 |
+
|
| 75 |
+
Examples
|
| 76 |
+
---------
|
| 77 |
+
One way of using this feature is through providing a trained topic model. A dictionary has to be explicitly provided
|
| 78 |
+
if the model does not contain a dictionary already
|
| 79 |
+
|
| 80 |
+
.. sourcecode:: pycon
|
| 81 |
+
|
| 82 |
+
>>> from gensim.test.utils import common_corpus, common_dictionary
|
| 83 |
+
>>> from gensim.models.ldamodel import LdaModel
|
| 84 |
+
>>> # Assuming CoherenceModel_ttc is imported or defined in the current scope
|
| 85 |
+
>>> # from your_module import CoherenceModel_ttc # if saved in a file
|
| 86 |
+
>>>
|
| 87 |
+
>>> model = LdaModel(common_corpus, 5, common_dictionary)
|
| 88 |
+
>>>
|
| 89 |
+
>>> cm = CoherenceModel_ttc(model=model, corpus=common_corpus, coherence='u_mass')
|
| 90 |
+
>>> coherence = cm.get_coherence() # get coherence value
|
| 91 |
+
|
| 92 |
+
Another way of using this feature is through providing tokenized topics such as:
|
| 93 |
+
|
| 94 |
+
.. sourcecode:: pycon
|
| 95 |
+
|
| 96 |
+
>>> from gensim.test.utils import common_corpus, common_dictionary
|
| 97 |
+
>>> # Assuming CoherenceModel_ttc is imported or defined in the current scope
|
| 98 |
+
>>> # from your_module import CoherenceModel_ttc # if saved in a file
|
| 99 |
+
>>> topics = [
|
| 100 |
+
... ['human', 'computer', 'system', 'interface'],
|
| 101 |
+
... ['graph', 'minors', 'trees', 'eps']
|
| 102 |
+
... ]
|
| 103 |
+
>>>
|
| 104 |
+
>>> cm = CoherenceModel_ttc(topics=topics, corpus=common_corpus, dictionary=common_dictionary, coherence='u_mass')
|
| 105 |
+
>>> coherence = cm.get_coherence() # get coherence value
|
| 106 |
+
|
| 107 |
+
"""
|
| 108 |
+
def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=None,
             window_size=None, keyed_vectors=None, coherence='c_v', topn=20, processes=-1):
    """Validate the inputs and initialise the coherence model state.

    Parameters
    ----------
    model : object, optional
        Trained topic model. Must be given when `topics` is None. When
        `dictionary` is None, `model.id2word` is used as the dictionary.
    topics : list of list of str, optional
        Tokenized topics. When given, `dictionary` is mandatory.
    texts : list of list of str, optional
        Tokenized texts. Required by sliding-window measures; for
        document-based measures they are converted to bag-of-words when
        `corpus` is not a usable corpus.
    corpus : iterable of list of (int, number), optional
        Corpus in bag-of-words format.
    dictionary : object, optional
        Token <-> id mapping used to interpret topics and texts.
    window_size : int, optional
        Sliding window size. When None, the default registered for
        `coherence` in `SLIDING_WINDOW_SIZES` is used.
    keyed_vectors : object, optional
        Word embeddings; accepted as the data source for 'c_w2v'.
    coherence : str, optional
        Name of the coherence measure.
    topn : int, optional
        Number of top words taken from each topic.
    processes : int, optional
        Worker count for probability estimation. Values below 1 are
        replaced by ``max(1, cpu_count - 1)``.

    Raises
    ------
    ValueError
        If neither `model` nor `topics` is given, if `topics` is given
        without `dictionary`, if no data source is supplied, if the model
        carries only a placeholder dictionary, if the chosen measure lacks
        the input it needs, or if `coherence` is not supported.
    """
    # Either a model or explicit topics must describe what is being scored.
    if model is None and topics is None:
        raise ValueError("One of 'model' or 'topics' has to be provided.")
    elif topics is not None and dictionary is None:
        raise ValueError("Dictionary has to be provided if 'topics' are to be used.")

    self.keyed_vectors = keyed_vectors
    if keyed_vectors is None and texts is None and corpus is None:
        raise ValueError("One of 'texts', 'corpus', or 'keyed_vectors' has to be provided.")

    if dictionary is None:
        # Reaching here implies `topics` is None, hence `model` is not None.
        if isinstance(model.id2word, utils.FakeDict):
            raise ValueError(
                "The associated dictionary should be provided with the corpus or 'id2word'"
                " for topic model should be set as the associated dictionary.")
        self.dictionary = model.id2word
    else:
        self.dictionary = dictionary

    self.coherence = coherence
    self.window_size = window_size
    if self.window_size is None:
        # Use .get() so an unknown coherence name falls through to the
        # explicit "not currently supported" ValueError below instead of
        # surfacing as a bare KeyError from the defaults table.
        self.window_size = SLIDING_WINDOW_SIZES.get(self.coherence)

    self.texts = texts
    self.corpus = corpus

    if coherence in BOOLEAN_DOCUMENT_BASED:
        if utils.is_corpus(corpus)[0]:
            self.corpus = corpus
        elif self.texts is not None:
            # No usable corpus: derive bag-of-words from the tokenized texts.
            self.corpus = [self.dictionary.doc2bow(text) for text in self.texts]
        else:
            raise ValueError(
                "Either 'corpus' with 'dictionary' or 'texts' should "
                "be provided for %s coherence." % coherence)
    elif coherence == 'c_w2v' and keyed_vectors is not None:
        # Embeddings were supplied, so no texts are needed at this point.
        pass
    elif coherence in SLIDING_WINDOW_BASED:
        if self.texts is None:
            raise ValueError("'texts' should be provided for %s coherence." % coherence)
    else:
        raise ValueError("%s coherence is not currently supported." % coherence)

    self._topn = topn
    self._model = model
    self._accumulator = None  # cached probability statistics
    self._topics = None
    # Goes through the `topics` setter, which converts tokens to ids (or
    # pulls the topics from the model when `topics` is None).
    self.topics = topics

    self.processes = processes if processes >= 1 else max(1, mp.cpu_count() - 1)
|
| 211 |
+
|
| 212 |
+
@classmethod
def for_models(cls, models, dictionary, topn=20, **kwargs):
    """Build a coherence model whose probabilities cover every given model.

    The top words of each model are extracted with
    `top_topics_as_word_lists` and the resulting word lists are handed to
    `for_topics`.

    Parameters
    ----------
    models : iterable
        Topic models; each is passed to `top_topics_as_word_lists`.
    dictionary : object
        Dictionary used to turn word ids into words.
    topn : int, optional
        Number of top words to take from each topic. Defaults to 20.
    kwargs : object
        Extra keyword arguments forwarded to `for_topics`. The
        `dictionary` and `topn` entries are overwritten with the values
        passed here.

    Returns
    -------
    object
        Whatever `cls.for_topics` returns.
    """
    word_lists_per_model = []
    for candidate in models:
        word_lists_per_model.append(
            cls.top_topics_as_word_lists(candidate, dictionary, topn))

    kwargs.update(dictionary=dictionary, topn=topn)
    return cls.for_topics(word_lists_per_model, **kwargs)
|
| 259 |
+
|
| 260 |
+
@staticmethod
def top_topics_as_word_lists(model, dictionary, topn=20):
    """Return the `topn` highest-weighted words of every topic in `model`.

    Parameters
    ----------
    model : object
        Topic model exposing `get_topics()`, which yields one weight
        vector per topic.
    dictionary : object
        Dictionary with `token2id` and `id2token` mappings. If `id2token`
        is empty it is filled in-place from `token2id`.
    topn : int, optional
        Number of top words per topic. Defaults to 20.

    Returns
    -------
    list of list of str
        One word list per topic, ordered by descending weight.
    """
    if not dictionary.id2token:
        # Build the reverse mapping once so ids can be translated below.
        dictionary.id2token = {
            token_id: token for token, token_id in dictionary.token2id.items()
        }

    return [
        [
            dictionary.id2token[word_id]
            for word_id in matutils.argsort(weights, topn=topn, reverse=True)
        ]
        for weights in model.get_topics()
    ]
|
| 291 |
+
|
| 292 |
+
@classmethod
def for_topics(cls, topics_as_topn_terms, **kwargs):
    """Build a coherence model with probabilities estimated for all given topics.

    All topic sets are flattened into one "super topic", probabilities are
    estimated once for it, and the instance's `topn` is then set to the
    value that should be used when scoring.

    Parameters
    ----------
    topics_as_topn_terms : list of list of list of str
        One entry per topic set (e.g. per model); each entry is a list of
        topics and each topic is a list of words.
    kwargs : object
        Keyword arguments forwarded to the constructor. An optional
        `topn` entry is consumed here and capped at the length of the
        longest topic supplied.

    Returns
    -------
    object
        Instance of `cls` with probabilities already estimated.

    Raises
    ------
    ValueError
        If `topics_as_topn_terms` is empty or contains an empty topic set.
    """
    if not topics_as_topn_terms:
        raise ValueError("len(topics_as_topn_terms) must be > 0.")
    if any(len(topic_list) == 0 for topic_list in topics_as_topn_terms):
        raise ValueError("Found an empty topic listing in `topics_as_topn_terms`.")

    # Length of the longest topic across all sets; upper bound for `topn`.
    longest_topic = max(
        (len(topic) for topic_list in topics_as_topn_terms for topic in topic_list),
        default=0,
    )

    # `topn` used for scoring after estimation: the caller's value, but
    # never more words than the supplied topics actually contain.
    scoring_topn = min(kwargs.pop('topn', longest_topic), longest_topic)

    # One flattened topic makes the accumulator cover every relevant word.
    super_topic = utils.flatten(topics_as_topn_terms)

    logger.info(
        "Number of relevant terms for all %d models (or topic sets): %d",
        len(topics_as_topn_terms), len(super_topic))

    # Use `cls` rather than the hard-coded class name so that subclasses
    # calling this alternate constructor receive their own type.
    cm = cls(topics=[super_topic], topn=len(super_topic), **kwargs)
    cm.estimate_probabilities()

    cm.topn = scoring_topn
    return cm
|
| 345 |
+
|
| 346 |
+
def __str__(self):
    """Return the string form of the active coherence pipeline (`self.measure`)."""
    pipeline = self.measure
    return str(pipeline)
|
| 349 |
+
|
| 350 |
+
@property
def model(self):
    """Topic model currently attached to this instance.

    Returns
    -------
    object
        The value stored in `_model` (None when no model was set).
    """
    attached_model = self._model
    return attached_model
|
| 361 |
+
|
| 362 |
+
@model.setter
def model(self, model):
    """Attach a topic model and refresh the stored topics from it.

    Parameters
    ----------
    model : object
        New topic model. When None, only `_model` is updated; the stored
        topics and the cached accumulator are left untouched.
    """
    self._model = model
    if model is None:
        return

    # Pull the topics from the new model, drop the cached accumulator if
    # it no longer covers them, then store them.
    refreshed_topics = self._get_topics()
    self._update_accumulator(refreshed_topics)
    self._topics = refreshed_topics
|
| 378 |
+
|
| 379 |
+
@property
def topn(self):
    """Number of top words per topic used when scoring.

    Returns
    -------
    int
        The value stored in `_topn`.
    """
    word_count = self._topn
    return word_count
|
| 390 |
+
|
| 391 |
+
@topn.setter
def topn(self, topn):
    """Change the number of top words per topic used when scoring.

    Parameters
    ----------
    topn : int
        New number of top words.

    Raises
    ------
    ValueError
        If no model is attached and the first stored topic has fewer than
        `topn` words, so the topics cannot be re-extracted.
    """
    # Only the first stored topic is inspected to decide whether the
    # stored topics are too short for the requested size.
    needs_more_words = len(self._topics[0]) < topn

    if self.model is None:
        if needs_more_words:
            raise ValueError("Model unavailable and topic sizes are less than topn=%d" % topn)
        # Shorter sizes need no re-extraction; the `topics` getter slices.
        self._topn = topn
        return

    self._topn = topn
    if needs_more_words:
        # Re-assigning the model runs its setter, which re-extracts the
        # topics using the new `_topn`.
        self.model = self._model
|
| 420 |
+
|
| 421 |
+
@property
def measure(self):
    """Pipeline registered for the configured coherence type.

    Returns
    -------
    object
        The entry of `COHERENCE_MEASURES` keyed by `self.coherence`. Other
        methods of this class use its `seg`, `prob`, `conf` and `aggr`
        attributes.
    """
    coherence_name = self.coherence
    return COHERENCE_MEASURES[coherence_name]
|
| 434 |
+
|
| 435 |
+
@property
def topics(self):
    """Stored topics, each limited to at most `_topn` entries.

    Every topic is checked individually: topics can have different
    lengths (tokens missing from the dictionary are dropped when topics
    are set), so inspecting only the first topic could leave longer ones
    untruncated.

    Returns
    -------
    list
        The stored topic list itself when no topic exceeds `_topn`
        entries (including when the list is empty); otherwise a new list
        in which every topic is sliced to its first `_topn` entries.
    """
    limit = self._topn
    if any(len(topic) > limit for topic in self._topics):
        return [topic[:limit] for topic in self._topics]
    return self._topics
|
| 451 |
+
|
| 452 |
+
@topics.setter
def topics(self, topics):
    """Store new topics and invalidate the cached accumulator if needed.

    Parameters
    ----------
    topics : list of list of str or list of list of int, or None
        Topics to store; each one is passed through
        `_ensure_elements_are_ids`. When None, the topics are taken from
        the attached model if there is one, otherwise None is stored.
    """
    has_model = self.model is not None

    if topics is not None:
        converted = [self._ensure_elements_are_ids(topic) for topic in topics]
        if has_model:
            # Explicit topics may not match what the attached model holds.
            logger.warning(
                "The currently set model '%s' may be inconsistent with the newly set topics",
                self.model)
    elif has_model:
        converted = self._get_topics()
        logger.debug("Setting topics to those of the model: %s", self.model)
    else:
        converted = None

    # Must run before `_topics` is overwritten: the check compares the new
    # topics against the ones currently stored.
    self._update_accumulator(converted)
    self._topics = converted
|
| 485 |
+
|
| 486 |
+
def _ensure_elements_are_ids(self, topic):
    """Convert one topic to an array of dictionary ids.

    The topic is read both as a list of tokens and as a list of ids, and
    the reading that matches more dictionary entries wins. Elements that
    match neither are dropped. Filtering tokens up front means no
    `KeyError` can occur, so an exception-based fallback would never
    handle id input; comparing the two readings handles it explicitly.

    Parameters
    ----------
    topic : iterable of str or iterable of int
        A single topic, given either as word tokens or as word ids.

    Returns
    -------
    :class:`numpy.ndarray`
        Ids of the elements found in the dictionary. On a tie (including
        when nothing matches) the token reading is returned, so unknown
        tokens yield an empty array rather than an error.
    """
    token2id = self.dictionary.token2id
    ids_from_tokens = [token2id[token] for token in topic if token in token2id]
    # `in self.dictionary` tests whether the element is a known id.
    ids_from_ids = [_id for _id in topic if _id in self.dictionary]

    if len(ids_from_ids) > len(ids_from_tokens):
        return np.array(ids_from_ids)
    return np.array(ids_from_tokens)
|
| 520 |
+
|
| 521 |
+
def _update_accumulator(self, new_topics):
    """Drop the cached accumulator when it cannot serve `new_topics`.

    The decision is delegated to `_relevant_ids_will_differ`; when it
    returns True the cache is reset to None.
    """
    if not self._relevant_ids_will_differ(new_topics):
        return
    logger.debug("Wiping cached accumulator since it does not contain all relevant ids.")
    self._accumulator = None
|
| 529 |
+
|
| 530 |
+
def _relevant_ids_will_differ(self, new_topics):
    """Tell whether `new_topics` need ids the cached accumulator lacks.

    Parameters
    ----------
    new_topics : list
        Candidate topics (as word ids).

    Returns
    -------
    bool
        False when there is no cached accumulator or the topics are
        unchanged; otherwise True exactly when the accumulator's
        `relevant_ids` is not a superset of the ids in the segmented
        new topics.
    """
    if self._accumulator is None:
        return False
    if not self._topics_differ(new_topics):
        return False

    required_ids = unique_ids_from_segments(self.measure.seg(new_topics))
    covered_ids = self._accumulator.relevant_ids
    return not covered_ids.issuperset(required_ids)
|
| 553 |
+
|
| 554 |
+
def _topics_differ(self, new_topics):
    """Tell whether `new_topics` differ from the stored topics.

    Parameters
    ----------
    new_topics : list or None
        Candidate topics (as word ids).

    Returns
    -------
    bool
        False when either side is None; otherwise the negation of
        `np.array_equal(new_topics, self._topics)`.
    """
    if new_topics is None or self._topics is None:
        return False
    return not np.array_equal(new_topics, self._topics)
|
| 572 |
+
|
| 573 |
+
def _get_topics(self):
    """Extract the top `self.topn` word ids per topic from the attached model."""
    source_model, word_count = self.model, self.topn
    return self._get_topics_from_model(source_model, word_count)
|
| 578 |
+
|
| 579 |
+
@staticmethod
def _get_topics_from_model(model, topn):
    """Extract the ids of the `topn` highest-weighted words of each topic.

    Parameters
    ----------
    model : object
        Topic model exposing `get_topics()`, which yields one weight
        vector per topic.
    topn : int
        Number of top words to take from each topic.

    Returns
    -------
    list
        One `matutils.argsort` result per topic, in descending weight
        order.

    Raises
    ------
    ValueError
        If an `AttributeError` occurs during extraction, e.g. because the
        model has no `get_topics` method.
    """
    try:
        top_ids_per_topic = []
        for weights in model.get_topics():
            top_ids_per_topic.append(matutils.argsort(weights, topn=topn, reverse=True))
        return top_ids_per_topic
    except AttributeError:
        raise ValueError(
            "This topic model is not currently supported. Supported topic models"
            " should implement the `get_topics` method.")
|
| 612 |
+
|
| 613 |
+
def segment_topics(self):
    """Segment the current topics with the pipeline's `seg` function.

    Returns
    -------
    object
        Result of ``self.measure.seg(self.topics)``; its structure is
        defined by the segmentation function of the chosen measure.
    """
    segment = self.measure.seg
    return segment(self.topics)
|
| 625 |
+
|
| 626 |
+
def estimate_probabilities(self, segmented_topics=None):
    """Run the pipeline's probability estimation and cache the result.

    Parameters
    ----------
    segmented_topics : object, optional
        Segmented topics. When None, `self.segment_topics()` is used.

    Returns
    -------
    object
        The accumulator returned by ``self.measure.prob``; it is also
        stored in `_accumulator`.
    """
    if segmented_topics is None:
        segmented_topics = self.segment_topics()

    estimate = self.measure.prob
    if self.coherence in BOOLEAN_DOCUMENT_BASED:
        # Document-based measures are estimated from the bag-of-words corpus.
        self._accumulator = estimate(self.corpus, segmented_topics)
        return self._accumulator

    estimator_args = {
        'texts': self.texts,
        'segmented_topics': segmented_topics,
        'dictionary': self.dictionary,
        'window_size': self.window_size,
        'processes': self.processes,
    }
    if self.coherence == 'c_w2v':
        # The embedding-based measure additionally receives the keyed vectors.
        estimator_args['model'] = self.keyed_vectors

    self._accumulator = estimate(**estimator_args)
    return self._accumulator
|
| 659 |
+
|
| 660 |
+
def get_coherence_per_topic(self, segmented_topics=None, with_std=False, with_support=False):
    """Compute the confirmation measure for each topic.

    Parameters
    ----------
    segmented_topics : object, optional
        Segmented topics. When None, the current topics are segmented
        with the pipeline's `seg` function.
    with_std : bool, optional
        Forwarded to the confirmation function. Defaults to False.
    with_support : bool, optional
        Forwarded to the confirmation function. Defaults to False.

    Returns
    -------
    object
        Result of the pipeline's `conf` function for the segmented topics
        and the (possibly freshly estimated) accumulator.
    """
    pipeline = self.measure
    if segmented_topics is None:
        segmented_topics = pipeline.seg(self.topics)

    # Probabilities are estimated lazily, only when nothing is cached.
    if self._accumulator is None:
        self.estimate_probabilities(segmented_topics)

    conf_args = {'with_std': with_std, 'with_support': with_support}
    takes_no_extras = self.coherence in BOOLEAN_DOCUMENT_BASED or self.coherence == 'c_w2v'
    if not takes_no_extras:
        if self.coherence == 'c_v':
            conf_args['topics'] = self.topics
            conf_args['measure'] = 'nlr'
            conf_args['gamma'] = 1
        else:
            # Remaining measures differ only in whether they normalize.
            conf_args['normalize'] = (self.coherence == 'c_npmi')

    return pipeline.conf(segmented_topics, self._accumulator, **conf_args)
|
| 705 |
+
|
| 706 |
+
def aggregate_measures(self, topic_coherences):
    """Combine per-topic coherence values with the pipeline's `aggr` function.

    Parameters
    ----------
    topic_coherences : list of float
        Coherence value of each topic.

    Returns
    -------
    object
        Result of ``self.measure.aggr(topic_coherences)``.
    """
    aggregate = self.measure.aggr
    return aggregate(topic_coherences)
|
| 723 |
+
|
| 724 |
+
def get_coherence(self):
    """Compute the overall coherence of the current topics.

    Returns
    -------
    object
        The per-topic values from `get_coherence_per_topic`, combined by
        `aggregate_measures`.
    """
    return self.aggregate_measures(self.get_coherence_per_topic())
|
| 738 |
+
|
| 739 |
+
def compare_models(self, models):
    """Compare several topic models by coherence.

    The top `self.topn` word ids are extracted from each model with
    `_get_topics_from_model`, and the comparison itself is delegated to
    `compare_model_topics`.

    Parameters
    ----------
    models : iterable
        Topic models to compare.

    Returns
    -------
    object
        Whatever `compare_model_topics` returns for the extracted topics.
    """
    topics_per_model = []
    for candidate in models:
        topics_per_model.append(self._get_topics_from_model(candidate, self.topn))
    return self.compare_model_topics(topics_per_model)
|
| 759 |
+
|
| 760 |
+
def compare_model_topics(self, model_topics):
    """
    Evaluate coherence for every topic set in `model_topics`.

    The instance's current topics and `topn` are remembered before the
    evaluation and written back afterwards, also when the evaluation raises,
    because `_compare_model_topics` overwrites both while it works.

    Parameters
    ----------
    model_topics : list of list of list of int
        Each element is one set of topics (e.g. from a single model); each
        topic is a list of word IDs.

    Returns
    -------
    list of (list of float, float)
        One pair per topic set:
        (average topic coherences for the topic set, overall topic set coherence).

    Notes
    -----
    Coherence is evaluated at several `topn` values (e.g. 20, 15, 10, 5) and
    the results are averaged.
    """
    # Snapshot the state that the helper is going to mutate.
    saved_topics = self._topics
    saved_topn = self.topn

    try:
        return self._compare_model_topics(model_topics)
    finally:
        # Runs on success and on error alike.
        self.topics = saved_topics
        self.topn = saved_topn
|
| 796 |
+
|
| 797 |
+
def _compare_model_topics(self, model_topics):
|
| 798 |
+
"""
|
| 799 |
+
Internal helper to get average topic and model coherences across multiple sets of topics.
|
| 800 |
+
|
| 801 |
+
Parameters
|
| 802 |
+
----------
|
| 803 |
+
model_topics : list of list of list of int
|
| 804 |
+
A list where each element is a set of topics (list of lists of word IDs).
|
| 805 |
+
|
| 806 |
+
Returns
|
| 807 |
+
-------
|
| 808 |
+
list of (list of float, float)
|
| 809 |
+
A sequence of pairs:
|
| 810 |
+
(average topic coherences across different `topn` values for each topic,
|
| 811 |
+
overall model coherence averaged across different `topn` values).
|
| 812 |
+
"""
|
| 813 |
+
coherences = []
|
| 814 |
+
# Define a grid of `topn` values to evaluate coherence.
|
| 815 |
+
# This provides a more robust average coherence value.
|
| 816 |
+
# It goes from `self.topn` down to `min(self.topn - 1, 4)` in steps of -5.
|
| 817 |
+
# e.g., if self.topn is 20, grid might be [20, 15, 10, 5].
|
| 818 |
+
# The `min(self.topn - 1, 4)` ensures at least some lower values are included,
|
| 819 |
+
# but also prevents trying `topn` values that are too small or negative.
|
| 820 |
+
last_topn_value = min(self.topn - 1, 4)
|
| 821 |
+
topn_grid = list(range(self.topn, last_topn_value, -5))
|
| 822 |
+
if not topn_grid or max(topn_grid) < 1: # Ensure at least one valid topn if range is empty or too small
|
| 823 |
+
topn_grid = [max(1, min(self.topn, 5))] # Use min of self.topn and 5, ensure at least 1
|
| 824 |
+
|
| 825 |
+
for model_num, topics in enumerate(model_topics):
|
| 826 |
+
# Set the current topics for the instance to the topics of the model being evaluated
|
| 827 |
+
self.topics = topics
|
| 828 |
+
|
| 829 |
+
coherence_at_n = {} # Dictionary to store coherence results for different `topn` values
|
| 830 |
+
for n in topn_grid:
|
| 831 |
+
self.topn = n # Set the `topn` for the current evaluation round
|
| 832 |
+
topic_coherences = self.get_coherence_per_topic()
|
| 833 |
+
|
| 834 |
+
# Handle NaN values in topic coherences by imputing with the mean
|
| 835 |
+
filled_coherences = np.array(topic_coherences, dtype=float)
|
| 836 |
+
# Check for NaN values and replace them with the mean of non-NaN values.
|
| 837 |
+
# np.nanmean handles arrays with all NaNs gracefully by returning NaN.
|
| 838 |
+
if np.any(np.isnan(filled_coherences)):
|
| 839 |
+
mean_val = np.nanmean(filled_coherences)
|
| 840 |
+
if np.isnan(mean_val): # If all are NaN, mean_val will also be NaN. In this case, replace with 0 or a very small number.
|
| 841 |
+
filled_coherences[np.isnan(filled_coherences)] = 0.0 # Or another sensible default
|
| 842 |
+
else:
|
| 843 |
+
filled_coherences[np.isnan(filled_coherences)] = mean_val
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
# Store the topic-level coherences and the aggregated (overall) coherence for this `topn`
|
| 847 |
+
coherence_at_n[n] = (topic_coherences, self.aggregate_measures(filled_coherences))
|
| 848 |
+
|
| 849 |
+
# Unpack the stored coherences for different `topn` values
|
| 850 |
+
all_topic_coherences_at_n, all_avg_coherences_at_n = zip(*coherence_at_n.values())
|
| 851 |
+
|
| 852 |
+
# Calculate the average topic coherence across all `topn` values
|
| 853 |
+
# np.vstack stacks lists of topic coherences into a 2D array, then mean(0) computes mean for each topic.
|
| 854 |
+
avg_topic_coherences = np.vstack(all_topic_coherences_at_n).mean(axis=0)
|
| 855 |
+
|
| 856 |
+
# Calculate the overall model coherence by averaging the aggregated coherences from all `topn` values
|
| 857 |
+
model_coherence = np.mean(all_avg_coherences_at_n)
|
| 858 |
+
|
| 859 |
+
logging.info("Avg coherence for model %d: %.5f" % (model_num, model_coherence))
|
| 860 |
+
coherences.append((avg_topic_coherences.tolist(), model_coherence)) # Convert numpy array back to list for output
|
| 861 |
+
|
| 862 |
+
return coherences
|
backend/evaluation/eval.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# eval.py - dynamic topic quality metrics (TTC, TTS, TTQ, TC, TD, TQ)
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from gensim.corpora.dictionary import Dictionary
|
| 5 |
+
from gensim.models.coherencemodel import CoherenceModel
|
| 6 |
+
from backend.evaluation.CoherenceModel_ttc import CoherenceModel_ttc
|
| 7 |
+
from typing import List, Dict
|
| 8 |
+
|
| 9 |
+
class TopicQualityAssessor:
    """
    Quality metrics for dynamic topic models, computed from in-memory data.

    Metrics exposed by this class:
    - Temporal Topic Coherence (TTC)
    - Temporal Topic Smoothness (TTS)
    - Temporal Topic Quality (TTQ)
    - Yearly Topic Coherence (TC)
    - Yearly Topic Diversity (TD)
    - Yearly Topic Quality (TQ)
    """

    def __init__(self, topics: List[List[List[str]]], train_texts: List[List[str]], topn: int, coherence_type: str):
        """
        Store the reference corpus and reshape the topics for later scoring.

        Args:
            topics (List[List[List[str]]]): Nested list with structure (T, K, W):
                T time slices, K topics, W words.
            train_texts (List[List[str]]): Tokenized documents of the reference corpus.
            topn (int): Number of top words per topic used in the calculations.
            coherence_type (str): Coherence measure name (e.g., 'c_npmi', 'c_v').

        Raises:
            ValueError: If `topics` does not convert to a 3-dimensional array.
        """
        # Reference corpus and the dictionary built from it.
        self.texts = train_texts
        self.dictionary = Dictionary(self.texts)

        # Input arrives as (T, K, W); the temporal view used internally is (K, T, W).
        tkw = np.array(topics, dtype=object)
        n_dims = tkw.ndim
        if n_dims != 3:
            raise ValueError(f"Input 'topics' must be a 3-dimensional list/array. Got {n_dims} dimensions.")
        self.total_topics = tkw.transpose(1, 0, 2)  # Shape: (K, T, W)

        self.K, self.T, _ = self.total_topics.shape

        # For every topic chain, the consecutive (t, t+1) word-list pairs
        # that the smoothness score is computed on.
        consecutive_pairs = [
            [
                [self.total_topics[k, t].tolist(), self.total_topics[k, t + 1].tolist()]
                for t in range(self.T - 1)
            ]
            for k in range(self.K)
        ]
        self.group_topics = np.array(consecutive_pairs, dtype=object)

        # Back to (T, K, W) for the per-time-slice metrics.
        self.yearly_topics = self.total_topics.transpose(1, 0, 2)

        self.topn = topn
        self.coherence_type = coherence_type

    def _compute_coherence(self, topics: List[List[str]]) -> List[float]:
        """Per-topic coherence of `topics` using gensim's CoherenceModel."""
        model = CoherenceModel(
            topics=topics,
            texts=self.texts,
            dictionary=self.dictionary,
            coherence=self.coherence_type,
            topn=self.topn,
        )
        return model.get_coherence_per_topic()

    def _compute_coherence_ttc(self, topics: List[List[str]]) -> List[float]:
        """Per-topic coherence of `topics` using the project's CoherenceModel_ttc."""
        model = CoherenceModel_ttc(
            topics=topics,
            texts=self.texts,
            dictionary=self.dictionary,
            coherence=self.coherence_type,
            topn=self.topn,
        )
        return model.get_coherence_per_topic()

    def _topic_smoothness(self, topics: List[List[str]]) -> float:
        """
        Mean top-word overlap between every topic and all the other topics.

        For each topic, the overlap with every other topic is the number of
        shared words among the first `self.topn` words divided by `self.topn`.
        Returns 1.0 when there are fewer than two topics.
        """
        n_topics = len(topics)
        if n_topics <= 1:
            # Nothing to compare a lone topic against.
            return 1.0

        limit = self.topn
        per_topic_means = []
        for i, current in enumerate(topics):
            current_words = set(current[:limit])
            ratios = []
            for j, candidate in enumerate(topics):
                if j == i:
                    continue
                ratios.append(len(current_words & set(candidate[:limit])) / limit)
            per_topic_means.append(sum(ratios) / len(ratios))
        return float(sum(per_topic_means) / n_topics)

    def get_ttq_dataframe(self) -> pd.DataFrame:
        """Build a DataFrame of TTQ metrics with one row per topic chain."""
        topic_ids = list(range(self.K))

        coherence_lists = [
            self._compute_coherence_ttc(self.total_topics[k].tolist()) for k in topic_ids
        ]
        coherence_means = [float(np.mean(values)) for values in coherence_lists]

        smoothness_lists = [
            [self._topic_smoothness(pair) for pair in self.group_topics[k]] for k in topic_ids
        ]
        smoothness_means = [float(np.mean(values)) for values in smoothness_lists]

        frame = pd.DataFrame({
            'topic_idx': topic_ids,
            'temporal_coherence': coherence_lists,
            'temporal_smoothness': smoothness_lists,
            'avg_temporal_coherence': coherence_means,
            'avg_temporal_smoothness': smoothness_means,
        })
        frame['ttq_product'] = frame['avg_temporal_coherence'] * frame['avg_temporal_smoothness']
        return frame

    def get_tq_dataframe(self) -> pd.DataFrame:
        """Build a DataFrame of TQ metrics with one row per time slice."""
        coherence_lists = []
        coherence_means = []
        diversities = []
        for t in range(self.T):
            slice_topics = self.yearly_topics[t].tolist()
            scores = self._compute_coherence(slice_topics)
            coherence_lists.append(scores)
            coherence_means.append(float(np.mean(scores)))
            # Diversity is the complement of the overlap between topics.
            diversities.append(1 - self._topic_smoothness(slice_topics))

        frame = pd.DataFrame({
            'year': list(range(self.T)),
            'all_coherence': coherence_lists,
            'avg_coherence': coherence_means,
            'diversity': diversities,
        })
        frame['tq_product'] = frame['avg_coherence'] * frame['diversity']
        return frame

    def get_ttc_score(self) -> float:
        """Overall Temporal Topic Coherence (TTC): mean over topic chains."""
        return self.get_ttq_dataframe()['avg_temporal_coherence'].mean()

    def get_tts_score(self) -> float:
        """Overall Temporal Topic Smoothness (TTS): mean over topic chains."""
        return self.get_ttq_dataframe()['avg_temporal_smoothness'].mean()

    def get_ttq_score(self) -> float:
        """Overall Temporal Topic Quality (TTQ): mean of the per-chain products."""
        return self.get_ttq_dataframe()['ttq_product'].mean()

    def get_tc_score(self) -> float:
        """Overall yearly Topic Coherence (TC): mean over time slices."""
        return self.get_tq_dataframe()['avg_coherence'].mean()

    def get_td_score(self) -> float:
        """Overall yearly Topic Diversity (TD): mean over time slices."""
        return self.get_tq_dataframe()['diversity'].mean()

    def get_tq_score(self) -> float:
        """Overall yearly Topic Quality (TQ): mean of the per-slice products."""
        return self.get_tq_dataframe()['tq_product'].mean()

    def get_dtq_summary(self) -> Dict[str, float]:
        """
        Compute all six dynamic topic quality metrics in one call.

        Each of the two DataFrames is built once and reused for its three metrics.
        """
        temporal = self.get_ttq_dataframe()
        yearly = self.get_tq_dataframe()
        columns = (
            ('TTC', temporal, 'avg_temporal_coherence'),
            ('TTS', temporal, 'avg_temporal_smoothness'),
            ('TTQ', temporal, 'ttq_product'),
            ('TC', yearly, 'avg_coherence'),
            ('TD', yearly, 'diversity'),
            ('TQ', yearly, 'tq_product'),
        )
        return {label: frame[column].mean() for label, frame, column in columns}
|
backend/inference/doc_retriever.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import html
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
import os
|
| 5 |
+
from hashlib import md5
|
| 6 |
+
|
| 7 |
+
def deduplicate_docs(collected_docs):
    """
    Return the documents with duplicates removed, keeping first occurrences.

    A document is identified by its "id" value when that key is present;
    otherwise by the MD5 hex digest of its "text". The hash is computed only
    when it is needed, so documents that carry an "id" but no "text" are
    handled and no hashing work is wasted on documents that have an id.

    Args:
        collected_docs (Iterable[dict]): Documents, each with an "id" and/or
            a "text" key.

    Returns:
        list[dict]: The unique documents in their original order.

    Raises:
        KeyError: If a document has neither "id" nor "text".
    """
    seen = set()
    unique_docs = []
    for doc in collected_docs:
        # Prefer the explicit ID; fall back to a content hash.
        if "id" in doc:
            key = doc["id"]
        else:
            key = md5(doc["text"].encode()).hexdigest()
        if key not in seen:
            seen.add(key)
            unique_docs.append(doc)
    return unique_docs
|
| 17 |
+
|
| 18 |
+
def load_length_stats(length_stats_path):
    """
    Load document length statistics from a JSON file.

    Args:
        length_stats_path (str): Path to the JSON file itself (the path is
            opened directly, not treated as a directory).

    Returns:
        The parsed JSON content (documented by the original author as a
        dictionary of document length statistics).

    Raises:
        FileNotFoundError: If nothing exists at `length_stats_path`.
    """
    if not os.path.exists(length_stats_path):
        raise FileNotFoundError(f"'length_stats.json' not found at: {length_stats_path}")

    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(length_stats_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
| 35 |
+
|
| 36 |
+
def get_yearly_counts_for_word(index, word):
    """
    Return the years in which `word` occurs and the document count per year.

    Args:
        index (dict): Mapping word -> {year: collection of doc ids}.
        word (str): The word to look up.

    Returns:
        tuple[list, list]: Years (as ints, ascending) and the matching
        document counts. Both lists are empty when the word is not indexed
        (a message is printed in that case) or has no entries.
    """
    if word not in index:
        print(f"[ERROR] Word '{word}' not found in index.")
        return [], []

    pairs = [(int(year), len(doc_ids)) for year, doc_ids in index[word].items()]
    pairs.sort()
    if not pairs:
        return [], []

    years = [year for year, _ in pairs]
    counts = [count for _, count in pairs]
    return years, counts
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_all_documents_for_word_year(index, docs_file_path, word, year):
    """
    Return all documents (text + metadata) that contain `word` in `year`.

    The corpus file is read line by line; the 0-based line number is the
    document id. Reading stops as soon as every wanted document has been
    collected, and the file is not opened at all when there is nothing to
    look for.

    Parameters:
        index (dict): Inverted index, word -> {year (int): list of doc ids}.
        docs_file_path (str): Path to original jsonl corpus.
        word (str): Word (unigram or bigram).
        year (int): Year to retrieve docs for (converted with int()).

    Returns:
        List[Dict]: Documents with 'id', 'timestamp' ("N/A" when the record
        has none) and 'text', in file order. If reading fails part-way, an
        error is printed and the documents collected so far are returned.
    """
    year = int(year)

    if word not in index or year not in index[word]:
        return []

    wanted_ids = set(index[word][year])
    results = []
    if not wanted_ids:
        # Nothing to fetch; avoid scanning the whole corpus.
        return results

    try:
        with open(docs_file_path, 'r', encoding='utf-8') as f:
            for doc_id, line in enumerate(f):
                if doc_id not in wanted_ids:
                    continue
                doc = json.loads(line)
                results.append({
                    "id": doc_id,
                    "timestamp": doc.get("timestamp", "N/A"),
                    "text": doc["text"]
                })
                if len(results) == len(wanted_ids):
                    # Every requested document was found; skip the rest.
                    break
    except Exception as e:
        # Deliberate best-effort: report the problem, keep partial results.
        print(f"[ERROR] Could not load documents: {e}")

    return results
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_documents_with_all_words_for_year(index, docs_path, words, year):
    """
    Return the documents of `year` that contain every word in `words`.

    Documents are fetched per word via `get_all_documents_for_word_year`;
    the result is the intersection of the per-word document-id sets.
    An empty `words` yields an empty list.
    """
    id_sets = []
    docs_by_id = {}

    for term in words:
        hits = get_all_documents_for_word_year(index, docs_path, term, year)
        id_sets.append({hit["id"] for hit in hits})
        for hit in hits:
            # Keep the first record seen for each id.
            if hit["id"] not in docs_by_id:
                docs_by_id[hit["id"]] = hit

    if not id_sets:
        return []

    shared_ids = set.intersection(*id_sets)
    return [docs_by_id[doc_id] for doc_id in shared_ids]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_intersection_doc_counts_by_year(index, docs_path, words, all_years):
    """
    Count, for each year in `all_years`, the documents containing all `words`.

    Returns:
        dict: year -> number of documents returned by
        `get_documents_with_all_words_for_year` for that year.
    """
    return {
        year: len(get_documents_with_all_words_for_year(index, docs_path, words, year))
        for year in all_years
    }
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def extract_snippet(text, query, window=30):
    """
    Return a short excerpt of `text` around the first occurrence of `query`.

    Underscores in `query` are treated as spaces and matching ignores case.
    The excerpt spans `window` characters on each side of the match, is
    stripped of surrounding whitespace and wrapped in "...". When the query
    does not occur, the first 200 characters followed by "..." are returned.
    """
    needle = re.escape(query.replace('_', ' '))
    hit = re.search(needle, text, flags=re.IGNORECASE)
    if hit is None:
        return text[:200] + "..."

    left = max(hit.start() - window, 0)
    right = min(hit.end() + window, len(text))
    excerpt = text[left:right].strip()
    return "..." + excerpt + "..."
|
| 120 |
+
|
| 121 |
+
def highlight(text, query, highlight_color="#FFD54F"):
    """
    Return HTML-escaped `text` with every occurrence of `query` wrapped in a
    colored <mark> tag.

    Underscores in `query` are treated as spaces and matching ignores case.
    Matching is done on the raw text and each piece is escaped exactly once,
    so a query can never match inside an HTML entity (e.g. "amp" inside
    "&amp;"), matched text is not double-escaped, and queries that contain
    characters such as "&" or "<" are still found.

    Args:
        text (str): Raw (unescaped) text.
        query (str): Term to highlight; an empty query highlights nothing.
        highlight_color (str): CSS color for the mark background. It is
            inserted into the style attribute as-is, so it must come from
            trusted code, not from user input.

    Returns:
        str: HTML-safe text with <mark> tags around the matches.
    """
    needle = query.replace('_', ' ')
    if not needle:
        # An empty pattern would match at every position.
        return html.escape(text)

    pattern = re.compile(re.escape(needle), flags=re.IGNORECASE)

    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[cursor:match.start()]))
        matched_text = html.escape(match.group(0))
        pieces.append(
            f"<mark style='background-color:{highlight_color}; color:black;'>{matched_text}</mark>"
        )
        cursor = match.end()
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)
|
| 133 |
+
|
| 134 |
+
def highlight_words(text, query_words, highlight_color="#24F31D", lemma_to_forms=None):
    """
    Return HTML-escaped `text` with all surface forms of each query lemma
    wrapped in a colored <mark> tag.

    All forms are matched in a single pass over the raw text (whole words,
    case-insensitive, longest form first) and every piece is escaped exactly
    once. Doing it in one pass means a later word can never match inside a
    <mark> tag inserted for an earlier word (e.g. "color", "mark", "style")
    or inside an HTML entity (e.g. "lt", "amp"), and overlapping forms do not
    produce nested marks.

    Args:
        text (str): The input raw document text.
        query_words (List[str]): Lemmatized query tokens to highlight.
        highlight_color (str): Color to use for highlighting. Inserted into
            the style attribute as-is; must come from trusted code.
        lemma_to_forms (Dict[str, Set[str]]): Maps a lemma to its surface
            forms. Lemmas missing from the map are matched literally.

    Returns:
        str: HTML-safe text with <mark> tags around the matches.
    """
    # Expand query words to include all surface forms.
    expanded_forms = set()
    for lemma in query_words:
        if lemma_to_forms and lemma in lemma_to_forms:
            expanded_forms.update(lemma_to_forms[lemma])
        else:
            expanded_forms.add(lemma)  # Fallback if map is missing

    # An empty alternative would match at every word boundary.
    expanded_forms.discard("")
    if not expanded_forms:
        return html.escape(text)

    # Longest first so "running" wins over "run"; ties broken alphabetically
    # to keep the pattern deterministic.
    ordered_forms = sorted(expanded_forms, key=lambda w: (-len(w), w))
    alternation = "|".join(re.escape(form) for form in ordered_forms)
    pattern = re.compile(rf'\b(?:{alternation})\b', flags=re.IGNORECASE)

    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[cursor:match.start()]))
        matched_text = html.escape(match.group(0))
        pieces.append(
            f"<mark style='background-color:{highlight_color}; color:black;'>{matched_text}</mark>"
        )
        cursor = match.end()
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)
|
| 169 |
+
|
| 170 |
+
def get_docs_by_ids(docs_file_path, doc_ids):
    """
    Fetch specific documents from a .jsonl file by line number (ID).

    The file is streamed line by line and only the requested lines are
    parsed; reading stops once every requested document has been found.

    Args:
        docs_file_path (str): The path to the documents.jsonl file.
        doc_ids (list or set): Document IDs (0-indexed line numbers) to retrieve.

    Returns:
        list[dict]: The documents that were found, in the iteration order of
        `doc_ids`, each with an 'id' key set to its line number. Requested
        lines holding malformed JSON are skipped with a warning. An empty
        list is returned when the file is missing or another error occurs
        while reading.
    """
    wanted = set(doc_ids)  # set membership keeps the per-line check cheap
    located = {}

    if not wanted:
        return []

    try:
        with open(docs_file_path, 'r', encoding='utf-8') as handle:
            for line_no, raw in enumerate(handle):
                if line_no not in wanted:
                    continue
                try:
                    record = json.loads(raw)
                except json.JSONDecodeError:
                    # Skip malformed lines but inform the user
                    print(f"[WARNING] Skipping malformed JSON on line {line_no+1} in {docs_file_path}")
                    continue
                # The line number is the document's id.
                record['id'] = line_no
                located[line_no] = record
                if len(located) == len(wanted):
                    # Everything requested has been collected.
                    break
    except FileNotFoundError:
        print(f"[ERROR] Document file not found at: {docs_file_path}")
        return []
    except Exception as e:
        print(f"[ERROR] An unexpected error occurred while reading documents: {e}")
        return []

    # Preserve the caller's ordering of ids.
    ordered = []
    for doc_id in doc_ids:
        if doc_id in located:
            ordered.append(located[doc_id])
    return ordered
|
backend/inference/indexing_utils.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import spacy
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
|
| 7 |
+
# Load spaCy once
|
| 8 |
+
# Module-level spaCy pipeline, created at import time and reused by the
# functions below; the "parser" and "ner" components are disabled.
nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])
|
| 9 |
+
|
| 10 |
+
def tokenize(text):
    """Lower-case `text` and return its runs of word characters as a list."""
    lowered = text.lower()
    word_pattern = re.compile(r"\b\w+\b")
    return word_pattern.findall(lowered)
|
| 12 |
+
|
| 13 |
+
def has_bigram(tokens, bigram):
    """
    Tell whether the underscore-joined `bigram` occurs in `tokens`.

    `bigram` is split on "_" and the resulting parts must appear as a
    contiguous run in the token list.
    """
    parts = bigram.split('_')
    width = len(parts)
    last_start = len(tokens) - width
    return any(
        tokens[start:start + width] == parts
        for start in range(last_start + 1)
    )
|
| 19 |
+
|
| 20 |
+
def build_inverse_lemma_map(docs_file_path, cache_path=None):
    """
    Build or load a mapping from lemma -> set of surface forms seen in corpus.

    If `cache_path` is provided and exists, the map is loaded from it.
    Otherwise the map is built from the corpus and, when `cache_path` is
    given, saved there. The cache's parent directory is created when the
    path has one; a bare filename is written to the current directory.
    Cached forms are written in sorted order so the file is reproducible.

    Args:
        docs_file_path (str): Path to a .jsonl corpus; each line must be a
            JSON object with a 'text' field.
        cache_path (str | None): Optional JSON cache location.

    Returns:
        dict[str, set[str]]: lemma -> surface forms. A plain dict when loaded
        from the cache, a defaultdict(set) when freshly built.
    """
    if cache_path and os.path.exists(cache_path):
        print(f"[INFO] Loading cached lemma_to_forms from {cache_path}")
        with open(cache_path, "r", encoding="utf-8") as f:
            raw_map = json.load(f)
        return {lemma: set(forms) for lemma, forms in raw_map.items()}

    print(f"[INFO] Building inverse lemma map from {docs_file_path}...")
    lemma_to_forms = defaultdict(set)

    with open(docs_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            doc = json.loads(line)
            tokens = tokenize(doc['text'])
            # Lemmatize the normalized token stream with the shared pipeline.
            spacy_doc = nlp(" ".join(tokens))
            for token in spacy_doc:
                lemma_to_forms[token.lemma_].add(token.text.lower())

    if cache_path:
        print(f"[INFO] Saving lemma_to_forms to {cache_path}")
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:
            # os.makedirs("") raises, so only create a real directory part.
            os.makedirs(cache_dir, exist_ok=True)
        with open(cache_path, "w", encoding="utf-8") as f:
            json.dump({k: sorted(v) for k, v in lemma_to_forms.items()}, f, indent=2)

    return lemma_to_forms
|
| 50 |
+
|
| 51 |
+
def build_inverted_index(docs_file_path, vocab_set, lemma_map_path=None):
    """
    Build an inverted index word -> timestamp -> [doc ids] for a .jsonl corpus.

    Vocabulary entries without "_" are treated as lemmas and matched through
    their surface forms (from `build_inverse_lemma_map`); entries containing
    "_" are matched as contiguous token sequences via `has_bigram`. The
    0-based line number of a document is its id.

    Args:
        docs_file_path (str): Corpus path; each line is a JSON object with
            'text' and 'timestamp' fields.
        vocab_set (set): Vocabulary entries to index.
        lemma_map_path (str | None): Optional cache path for the lemma map.

    Returns:
        tuple: (index, docs, lemma_to_forms), where `docs` is a list of
        {"text", "timestamp"} dicts in file order.
    """
    vocab_unigrams = {term for term in vocab_set if '_' not in term}
    vocab_bigrams = {term for term in vocab_set if '_' in term}

    # Load or build lemma map
    lemma_to_forms = build_inverse_lemma_map(docs_file_path, cache_path=lemma_map_path)

    index = defaultdict(lambda: defaultdict(list))
    docs = []
    found_anywhere = set()

    with open(docs_file_path, 'r', encoding='utf-8') as handle:
        for doc_id, raw_line in enumerate(handle):
            record = json.loads(raw_line)
            text = record['text']
            timestamp = int(record['timestamp'])
            docs.append({"text": text, "timestamp": timestamp})

            tokens = tokenize(text)
            distinct_tokens = set(tokens)
            found_here = set()

            # A lemma is present when any of its surface forms is a token.
            for lemma in vocab_unigrams:
                if distinct_tokens & lemma_to_forms.get(lemma, set()):
                    index[lemma][timestamp].append(doc_id)
                    found_here.add(lemma)

            # Multi-word entries need the ordered token list.
            for bigram in vocab_bigrams:
                if bigram in found_here:
                    continue
                if has_bigram(tokens, bigram):
                    index[bigram][timestamp].append(doc_id)
                    found_here.add(bigram)

            found_anywhere |= found_here

            # Periodic progress report.
            processed = doc_id + 1
            if processed % 500 == 0:
                missing = vocab_set - found_anywhere
                print(f"[INFO] After {processed} docs, {len(missing)} vocab words still not seen.")
                print("Example missing words:", list(missing)[:5])

    missing_final = vocab_set - found_anywhere
    if missing_final:
        print(f"[WARNING] {len(missing_final)} vocab words were never found in any document.")
        print("Examples:", list(missing_final)[:10])

    return index, docs, lemma_to_forms
|
| 98 |
+
|
| 99 |
+
def save_index_to_disk(index, index_path):
    """
    Write the inverted index to `index_path` as JSON.

    Timestamp keys are converted to strings, since JSON object keys must be
    strings. The parent directory is created when the path has one; a bare
    filename (no directory part) is written to the current directory instead
    of failing in os.makedirs("").

    Args:
        index (dict): Mapping word -> {timestamp: list of doc ids}.
        index_path (str): Destination file path.
    """
    index_clean = {
        word: {str(ts): doc_ids for ts, doc_ids in ts_dict.items()}
        for word, ts_dict in index.items()
    }
    target_dir = os.path.dirname(index_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    with open(index_path, "w", encoding='utf-8') as f:
        json.dump(index_clean, f, ensure_ascii=False)
|
| 107 |
+
|
| 108 |
+
def load_index_from_disk(index_path):
    """
    Read an inverted index from a JSON file.

    Timestamp keys, stored as strings in the file, are converted back to
    ints. The result is a nested defaultdict (word -> timestamp -> list);
    words whose stored mapping is empty are not added.
    """
    with open(index_path, 'r', encoding='utf-8') as handle:
        stored = json.load(handle)

    rebuilt = defaultdict(lambda: defaultdict(list))
    flattened = (
        (term, int(ts_key), doc_ids)
        for term, by_timestamp in stored.items()
        for ts_key, doc_ids in by_timestamp.items()
    )
    for term, timestamp, doc_ids in flattened:
        rebuilt[term][timestamp] = doc_ids

    return rebuilt
|
| 118 |
+
|
| 119 |
+
def load_docs(docs_file_path):
    """
    Load a .jsonl corpus as a list of {"text", "timestamp"} dicts.

    Each line must be a JSON object with 'text' and 'timestamp' fields; the
    timestamp is converted with int(). Order follows the file.
    """
    def _to_doc(raw_line):
        record = json.loads(raw_line)
        return {
            "text": record["text"],
            "timestamp": int(record["timestamp"]),
        }

    with open(docs_file_path, 'r', encoding='utf-8') as handle:
        return [_to_doc(raw_line) for raw_line in handle]
|
| 129 |
+
|
| 130 |
+
def load_index(docs_file_path, vocab, index_path=None, lemma_map_path=None):
    """Return the inverted index, the documents and the lemma map.

    When ``index_path`` names an existing file, the index is read from it and
    the documents and lemma map are loaded separately. Otherwise everything
    is built from ``docs_file_path`` and, if ``index_path`` was given, the
    fresh index is written there.

    Args:
        docs_file_path: Path to the documents file.
        vocab: Iterable of vocabulary entries (converted to a ``set`` when
            the index has to be built).
        index_path: Optional path of the on-disk index cache.
        lemma_map_path: Optional cache path forwarded to the lemma-map helper.

    Returns:
        Tuple ``(index, docs, lemma_to_forms)``.
    """
    cached = bool(index_path) and os.path.exists(index_path)
    if cached:
        return (
            load_index_from_disk(index_path),
            load_docs(docs_file_path),
            build_inverse_lemma_map(docs_file_path, cache_path=lemma_map_path),
        )

    fresh_index, documents, lemma_forms = build_inverted_index(
        docs_file_path,
        set(vocab),
        lemma_map_path=lemma_map_path
    )

    # Persist the newly built index only when a cache location was requested.
    if index_path:
        save_index_to_disk(fresh_index, index_path)

    return fresh_index, documents, lemma_forms
|
backend/inference/peak_detector.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy.signal import find_peaks
|
| 3 |
+
|
| 4 |
+
def detect_peaks(trend, prominence=0.001, distance=2):
    """
    Locate local maxima in a sequence of values.

    Args:
        trend: List or np.array of floats.
        prominence: Minimum prominence passed to ``scipy.signal.find_peaks``.
        distance: Minimum index distance between neighbouring peaks, passed
            to ``scipy.signal.find_peaks``.

    Returns:
        List of integer positions in ``trend`` where peaks occur.
    """
    series = np.array(trend)
    # find_peaks returns (indices, properties); only the indices are needed.
    peak_positions = find_peaks(series, prominence=prominence, distance=distance)[0]
    return peak_positions.tolist()
|
backend/inference/process_beta.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
def load_beta_matrix(beta_path: str, vocab_path: str):
    """
    Load the beta matrix and the vocabulary list.

    Args:
        beta_path: Path to a ``.npy`` file readable by ``np.load``.
        vocab_path: Path to a UTF-8 text file with one word per line.

    Returns:
        beta: np.ndarray as stored on disk (expected shape: T x K x V)
        vocab: list of words, stripped of surrounding whitespace
    """
    beta = np.load(beta_path)  # shape: T x K x V
    # Read explicitly as UTF-8 so non-ASCII vocabulary entries do not depend
    # on the platform's default encoding.
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = [line.strip() for line in f]
    return beta, vocab
|
| 16 |
+
|
| 17 |
+
def get_top_words_at_time(beta, vocab, topic_id, time, top_n):
    """Return the ``top_n`` highest-weighted words of a topic at one time step.

    Args:
        beta: Array indexed as ``beta[time, topic_id, word]``.
        vocab: Sequence of words aligned with the last axis of ``beta``.
        topic_id: Topic index.
        time: Time-step index.
        top_n: Number of words to return.

    Returns:
        Up to ``top_n`` words, highest weight first; an empty list when
        ``top_n`` is zero or negative.
    """
    # A slice of [-0:] selects the WHOLE array, so without this guard
    # top_n == 0 would return every word instead of none.
    if top_n <= 0:
        return []
    topic_beta = beta[time, topic_id, :]
    top_indices = topic_beta.argsort()[-top_n:][::-1]
    return [vocab[i] for i in top_indices]
|
| 21 |
+
|
| 22 |
+
def get_top_words_over_time(beta, vocab, topic_id, top_n):
    """Return the ``top_n`` words of a topic ranked by their mean weight over time.

    Args:
        beta: Array indexed as ``beta[time, topic_id, word]``.
        vocab: Sequence of words aligned with the last axis of ``beta``.
        topic_id: Topic index.
        top_n: Number of words to return.

    Returns:
        Up to ``top_n`` words, highest mean weight first; an empty list when
        ``top_n`` is zero or negative.
    """
    # A slice of [-0:] selects the WHOLE array, so without this guard
    # top_n == 0 would return every word instead of none.
    if top_n <= 0:
        return []
    topic_beta = beta[:, topic_id, :]
    mean_beta = topic_beta.mean(axis=0)  # average over the time axis
    top_indices = mean_beta.argsort()[-top_n:][::-1]
    return [vocab[i] for i in top_indices]
|
| 27 |
+
|
| 28 |
+
def load_time_labels(time2id_path):
    """Load time labels ordered by their integer id.

    Args:
        time2id_path: Path to a JSON file mapping ``label -> id``.

    Returns:
        List of labels sorted by ascending id.
    """
    # JSON text is UTF-8; read it explicitly so the result does not depend on
    # the platform's default encoding.
    with open(time2id_path, 'r', encoding='utf-8') as f:
        time2id = json.load(f)
    # Invert and sort by id
    id2time = {v: k for k, v in time2id.items()}
    return [id2time[i] for i in sorted(id2time)]
|
backend/inference/word_selector.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy.special import softmax
|
| 3 |
+
|
| 4 |
+
def get_interesting_words(beta, vocab, topic_id, top_k_final=10, restrict_to=None):
    """
    Suggest interesting words for a topic by favouring "bursty" terms.

    The score of a word is the product of three factors: its peak weight in
    the topic relative to its own mean in the topic, its peak weight in the
    topic relative to its mean over all topics and timestamps, and an
    IDF-like term that is smaller for words active in many topics.

    Parameters:
    - beta: np.ndarray (T, K, V) - Topic-word weights for each timestamp.
      If rows along the last axis do not sum to ~1, a softmax is applied.
    - vocab: list of V words - The vocabulary.
    - topic_id: int - The ID of the topic to analyze.
    - top_k_final: int - The number of words to return.
    - restrict_to: optional list of str - Restricts scoring to a subset of words.

    Returns:
    - list of at most top_k_final words (strings), highest score first.
      Empty when restrict_to matches no vocabulary word.
    """
    _, K, V = beta.shape

    # --- 1. Detect whether softmax is needed ---
    row_sums = beta.sum(axis=2)
    is_prob_dist = np.allclose(row_sums, 1.0, atol=1e-2)

    if not is_prob_dist:
        print("🔁 Beta is not normalized — applying softmax across words per topic.")
        beta = softmax(beta / 1e-3, axis=2)

    # --- 2. Now extract normalized topic slice ---
    topic_beta = beta[:, topic_id, :]  # Shape: (T, V)

    # Mean and peak weight within the topic for each word
    mean_topic = topic_beta.mean(axis=0)  # Shape: (V,)
    peak_topic = topic_beta.max(axis=0)   # Shape: (V,)

    # Mean over all timestamps and topics, used as a baseline
    mean_all = beta.mean(axis=(0, 1))  # Shape: (V,)

    # Epsilon to prevent division by zero
    epsilon = 1e-9

    # --- 3. Score components ---
    # a) Burstiness: peak relative to the word's own mean in this topic.
    burstiness_score = peak_topic / (mean_topic + epsilon)

    # b) Peak specificity: peak in this topic relative to the global mean.
    peak_specificity_score = peak_topic / (mean_all + epsilon)

    # c) Uniqueness: fraction of timestamps each word is active per topic,
    #    summed over topics and turned into an IDF-like weight.
    active_in_topics = (beta > 1e-5).mean(axis=0)  # Shape: (K, V)
    idf_like = np.log((K + 1) / (active_in_topics.sum(axis=0) + 1))  # Shape: (V,)

    # --- 4. Final score ---
    final_scores = burstiness_score * peak_specificity_score * idf_like

    # --- 5. Rank and select top words ---
    if restrict_to is not None:
        restrict_set = set(restrict_to)
        word_indices = [i for i, w in enumerate(vocab) if w in restrict_set]
    else:
        # A plain list (not np.arange) keeps the emptiness test below
        # well-defined: truth-testing a multi-element array raises ValueError.
        word_indices = list(range(V))

    if not word_indices:
        return []

    # Stable sort: words with equal scores keep their vocabulary order.
    sorted_indices = sorted(word_indices, key=lambda i: -final_scores[i])

    return [vocab[i] for i in sorted_indices[:top_k_final]]
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_word_trend(beta, vocab, word, topic_id):
    """
    Extract one word's values over time within a single topic.

    Args:
        beta: np.ndarray of shape (T, K, V)
        vocab: list of vocab words
        word: word to look up in ``vocab``
        topic_id: index of the topic to inspect (0 <= topic_id < K)

    Returns:
        List of length T holding ``beta[t, topic_id, index_of_word]``.

    Raises:
        ValueError: If ``word`` is not in ``vocab`` or ``topic_id`` is out
            of range.
    """
    _, num_topics, _ = beta.shape

    if word not in vocab:
        raise ValueError(f"Word '{word}' not found in vocab.")

    topic_in_range = 0 <= topic_id < num_topics
    if not topic_in_range:
        raise ValueError(f"Invalid topic_id {topic_id}. Must be between 0 and {num_topics - 1}.")

    column = vocab.index(word)
    return beta[:, topic_id, column].tolist()
|
backend/llm/custom_gemini.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 2 |
+
from langchain_core.messages import AIMessage, HumanMessage
|
| 3 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ChatGemini(BaseChatModel):
    """Chat model that delegates generation to ``ChatGoogleGenerativeAI``.

    Human and AI messages are flattened into one newline-joined prompt that
    is sent to the wrapped client; other message types are ignored.
    """

    # NOTE(review): BaseChatModel is a pydantic model in current langchain_core
    # releases, so instance attributes must be declared as fields and
    # ``super().__init__`` must run before they are assigned — confirm against
    # the pinned langchain_core version.
    api_key: str
    model: str = "gemini-pro"
    temperature: float = 0.7
    # Wrapped ChatGoogleGenerativeAI instance, created in __init__.
    client: object = None

    def __init__(self, api_key: str, model: str = "gemini-pro", temperature: float = 0.7):
        """Store the settings and build the wrapped Gemini client.

        Args:
            api_key: Google API key forwarded as ``google_api_key``.
            model: Gemini model name.
            temperature: Sampling temperature.
        """
        super().__init__(api_key=api_key, model=model, temperature=temperature)
        self.client = ChatGoogleGenerativeAI(
            model=model,
            temperature=temperature,
            google_api_key=api_key
        )

    def _generate(self, messages: List, stop: List[str] = None, run_manager=None, **kwargs):
        """Generate a reply for ``messages``.

        Args:
            messages: LangChain message objects.
            stop: Optional stop sequences, forwarded to the wrapped client.
            run_manager: Accepted for compatibility with the base class; unused.

        Returns:
            A ``ChatResult`` holding one generation with the client's reply.
        """
        # Imported here so the module's top-level import block is unchanged.
        from langchain_core.outputs import ChatGeneration, ChatResult

        # Convert LangChain messages to string
        prompt = "\n".join(
            msg.content for msg in messages if isinstance(msg, (HumanMessage, AIMessage))
        )
        response = self.client.invoke(prompt, stop=stop)
        # The base class expects a ChatResult, not a bare message.
        return ChatResult(generations=[ChatGeneration(message=response)])

    @property
    def _llm_type(self) -> str:
        """Identifier of this chat-model implementation."""
        return "gemini"
|
backend/llm/custom_mistral.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
| 2 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 3 |
+
from langchain_core.outputs import ChatResult, ChatGeneration
|
| 4 |
+
import requests
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
class ChatMistral(BaseChatModel):
    """Chat model backed by a Hugging Face Inference API text-generation endpoint.

    Only the contents of ``HumanMessage`` entries are sent, joined by
    newlines. ``invoke`` is inherited from ``BaseChatModel`` and routes
    through ``_generate``.
    """

    # NOTE(review): BaseChatModel is a pydantic model in current langchain_core
    # releases, so instance attributes must be declared as fields and
    # ``super().__init__`` must run before they are assigned — confirm against
    # the pinned langchain_core version.
    # Annotated as ``object`` because the token may legitimately be None.
    hf_token: object = None
    model_url: str = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1"
    headers: dict = {}

    def __init__(self, hf_token=None, model_url=None):
        """Configure the endpoint.

        Args:
            hf_token: API token; falls back to the ``HF_TOKEN`` environment
                variable when falsy.
            model_url: Inference endpoint URL; falls back to the
                Mistral-7B-Instruct-v0.1 endpoint when falsy.
        """
        super().__init__()
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        self.model_url = model_url or "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1"
        self.headers = {"Authorization": f"Bearer {self.hf_token}"}

    def _call(self, prompt: str, timeout: float = 60) -> str:
        """POST ``prompt`` to the endpoint and return the generated text.

        Args:
            prompt: Text sent as ``inputs``.
            timeout: Seconds to wait for the HTTP response.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
            RuntimeError: If the JSON payload is not a non-empty list whose
                first item has a ``generated_text`` key.
        """
        response = requests.post(
            self.model_url,
            headers=self.headers,
            json={"inputs": prompt, "parameters": {"max_new_tokens": 256}},
            timeout=timeout,
        )
        response.raise_for_status()
        payload = response.json()
        # Report an unexpected payload explicitly instead of failing with an
        # opaque KeyError/IndexError on the indexing below.
        if not (isinstance(payload, list) and payload
                and isinstance(payload[0], dict) and "generated_text" in payload[0]):
            raise RuntimeError(f"Unexpected response from inference endpoint: {payload!r}")
        return payload[0]["generated_text"]

    def _generate(self, messages, stop=None, run_manager=None, **kwargs) -> ChatResult:
        """Generate a reply for ``messages`` and wrap it in a ``ChatResult``."""
        prompt = "\n".join([msg.content for msg in messages if isinstance(msg, HumanMessage)])
        reply = AIMessage(content=self._call(prompt))
        return ChatResult(generations=[ChatGeneration(message=reply)])

    @property
    def _llm_type(self) -> str:
        """Identifier of this chat-model implementation."""
        return "mistral"
|
backend/llm/llm_router.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_anthropic import ChatAnthropic
|
| 2 |
+
from backend.llm.custom_mistral import ChatMistral
|
| 3 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 4 |
+
from langchain_openai import ChatOpenAI
|
| 5 |
+
import os
|
| 6 |
+
import google.auth.transport.requests
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
# Best-effort connectivity probe through the configured proxy. It runs at
# import time, so it is bounded by a timeout and must never prevent the
# module from importing; ``resp`` is None when the probe fails.
try:
    resp = requests.get(
        "https://www.google.com",
        proxies={
            "http": os.getenv("http_proxy"),
            "https": os.getenv("https_proxy")
        },
        timeout=5,
    )
except requests.RequestException as exc:
    resp = None
    print(f"[WARNING] Proxy connectivity check failed: {exc}")
|
| 13 |
+
|
| 14 |
+
def list_supported_models(provider=None):
    """List the model names offered for a provider.

    Args:
        provider: One of "OpenAI", "Anthropic", "Gemini", "Mistral". Any
            other value (including None) selects the full catalog.

    Returns:
        A list of model names for a known provider, otherwise a dict mapping
        every provider name to its list of model names.
    """
    # Built on every call so callers always receive fresh objects.
    catalog = {
        "OpenAI": ["gpt-4.1-nano", "gpt-4o-mini"],
        "Anthropic": ["claude-3-opus-20240229", "claude-3-sonnet-20240229"],
        "Gemini": ["gemini-2.0-flash-lite", "gemini-1.5-flash"],
        "Mistral": ["mistral-small", "mistral-medium"]
    }
    for name, models in catalog.items():
        if provider == name:
            return models
    # Default fallback: all models grouped by provider
    return catalog
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_llm(provider: str, model: str, api_key: str = None):
    """Build a chat-model client for the given provider.

    Args:
        provider: One of "OpenAI", "Anthropic", "Gemini", "Mistral".
        model: Model name passed to the provider client.
        api_key: API key; when falsy, the provider's environment variable
            (OPENAI_API_KEY, ANTHROPIC_API_KEY, GEMINI_API_KEY or
            MISTRAL_API_KEY) is used instead.

    Returns:
        A configured chat-model instance.

    Raises:
        ValueError: If no API key is available or the provider is unknown.
    """
    if provider == "OpenAI":
        api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("Missing OpenAI API key.")
        return ChatOpenAI(model_name=model, temperature=0, openai_api_key=api_key)

    elif provider == "Anthropic":
        api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("Missing Anthropic API key.")
        return ChatAnthropic(model=model, temperature=0, anthropic_api_key=api_key)

    elif provider == "Gemini":
        api_key = api_key or os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("Missing Gemini API key.")
        # No proxy monkeypatch here. The previous patch assigned a lambda to
        # ``google.auth.transport.requests.requests.Request``; that attribute
        # is ``requests.Request`` itself, so the lambda called itself
        # recursively and passed a ``proxies`` keyword that
        # ``requests.Request`` does not accept.
        # NOTE(review): proxy settings are presumably picked up from the
        # standard proxy environment variables by the transport — confirm.
        return ChatGoogleGenerativeAI(model=model, temperature=0, google_api_key=api_key)

    elif provider == "Mistral":
        api_key = api_key or os.getenv("MISTRAL_API_KEY")
        if not api_key:
            raise ValueError("Missing Mistral API key.")
        # ChatMistral's constructor only takes (hf_token, model_url); the old
        # call with model/temperature/mistral_api_key raised TypeError.
        # NOTE(review): ``model`` is not forwarded because the constructor has
        # no model-name parameter — confirm the intended endpoint and token.
        return ChatMistral(hf_token=api_key)

    else:
        raise ValueError(f"Unsupported provider: {provider}")
|
| 73 |
+
|
backend/llm_utils/label_generator.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from hashlib import sha256
|
| 2 |
+
import json
|
| 3 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 4 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 5 |
+
from typing import Optional
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
#get_top_words_at_time
|
| 9 |
+
from backend.inference.process_beta import get_top_words_at_time
|
| 10 |
+
|
| 11 |
+
def label_topic_temporal(word_trajectory_str: str, llm, cache_path: Optional[str] = None) -> str:
    """
    Label a dynamic topic by providing the LLM with the top words over time.

    Labels are cached in a JSON file keyed by the SHA-256 of the trajectory
    string. A failed LLM call yields "Unknown Topic" but is NOT cached, so a
    later call can retry.

    Args:
        word_trajectory_str (str): Formatted keyword evolution string.
        llm: LangChain-compatible LLM instance.
        cache_path (Optional[str]): Path to the cache file (JSON).

    Returns:
        str: Short label for the topic.
    """
    topic_key = sha256(word_trajectory_str.encode()).hexdigest()

    # Load cache
    if cache_path is not None and os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as f:
            label_cache = json.load(f)
    else:
        label_cache = {}

    # Return cached result
    if topic_key in label_cache:
        return label_cache[topic_key]

    # Prompt template
    prompt = ChatPromptTemplate.from_template(
        "You are an expert in topic modeling and temporal data analysis. "
        "Given the top words for a topic across multiple time points, your task is to return a short, specific, descriptive topic label. "
        "Avoid vague, generic, or overly broad labels. Focus on consistent themes in the top words over time. "
        "Use concise noun phrases, 2–5 words max. Do NOT include any explanation, justification, or extra output.\n\n"
        "Top words over time:\n{trajectory}\n\n"
        "Return ONLY the label (no quotes, no extra text):"
    )
    chain = prompt | llm | StrOutputParser()

    try:
        label = chain.invoke({"trajectory": word_trajectory_str}).strip()
    except Exception as e:
        print(f"[Labeling Error] {e}")
        # Return the placeholder without writing it to the cache; caching it
        # would make one transient failure permanent for this topic.
        return "Unknown Topic"

    # Update cache and save
    label_cache[topic_key] = label
    if cache_path is not None:
        # os.makedirs("") raises, so only create a directory when the path
        # actually names one.
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(cache_path, "w", encoding="utf-8") as f:
            json.dump(label_cache, f, indent=2)

    return label
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_topic_labels(beta, vocab, time_labels, llm, cache_path):
    """Produce a label for every topic.

    For each topic, one line per time step is built in the form
    ``"<time label>: <top 10 words, comma separated>"``; the joined lines are
    passed to ``label_topic_temporal``.

    Args:
        beta: Array whose first axis is time and second axis is topic.
        vocab: Vocabulary passed through to ``get_top_words_at_time``.
        time_labels: Sequence indexed by time step.
        llm: LLM instance passed through to ``label_topic_temporal``.
        cache_path: Cache file path passed through to ``label_topic_temporal``.

    Returns:
        Dict mapping topic index to its label.
    """
    labels_by_topic = {}
    for topic in range(beta.shape[1]):
        lines = []
        for step in range(beta.shape[0]):
            stamp = time_labels[step]
            words = get_top_words_at_time(beta, vocab, topic, step, top_n=10)
            lines.append(f"{stamp}: {', '.join(words)}")
        trajectory = "\n".join(lines)
        labels_by_topic[topic] = label_topic_temporal(trajectory, llm=llm, cache_path=cache_path)
    return labels_by_topic
|
backend/llm_utils/summarizer.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import numpy as np
|
| 3 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 4 |
+
from sentence_transformers import SentenceTransformer
|
| 5 |
+
import faiss
|
| 6 |
+
|
| 7 |
+
from langchain.prompts import ChatPromptTemplate
|
| 8 |
+
from langchain.docstore.document import Document
|
| 9 |
+
from langchain.memory import ConversationBufferMemory
|
| 10 |
+
from langchain.chains import ConversationChain
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# --- MMR Utilities ---
|
| 17 |
+
def build_mmr_index(docs):
    """Embed documents and build a FAISS inner-product index over them.

    Entries without a ``'text'`` key are skipped. Embeddings are
    L2-normalized in place before being added to the index.

    Args:
        docs: Iterable of dicts; the ``'text'`` value becomes the page content.

    Returns:
        Tuple ``(model, index, embeddings, documents)``: the sentence
        encoder, the FAISS index, the normalized embedding matrix and the
        list of ``Document`` objects in the same order as the embeddings.
    """
    passages = [Document(page_content=entry['text']) for entry in docs if 'text' in entry]

    encoder = SentenceTransformer("all-MiniLM-L6-v2")
    vectors = encoder.encode([passage.page_content for passage in passages], convert_to_numpy=True)
    faiss.normalize_L2(vectors)

    flat_index = faiss.IndexFlatIP(vectors.shape[1])
    flat_index.add(vectors)

    return encoder, flat_index, vectors, passages
|
| 29 |
+
|
| 30 |
+
def get_mmr_sample(model, index, embeddings, documents, query, k=15, lambda_mult=0.7):
    """Pick up to ``k`` documents with maximal marginal relevance (MMR).

    The first pick is the top search hit for the query. Each later pick
    maximizes ``lambda_mult * relevance - (1 - lambda_mult) * redundancy``,
    where redundancy is the highest cosine similarity to a document already
    picked.

    Args:
        model: Encoder exposing ``encode(query, convert_to_numpy=True)``.
        index: Index exposing ``search(vectors, n)``.
        embeddings: Document embeddings, indexable by the ids the index returns.
        documents: Documents aligned with ``embeddings``.
        query: Query text.
        k: Number of documents to select.
        lambda_mult: Weight of relevance against redundancy.

    Returns:
        ``[]`` when there are no documents, ``documents`` itself when there
        are at most ``k`` of them, otherwise the selected documents in
        selection order.
    """
    total = len(documents)
    if total == 0:
        print("Warning: No documents available, returning empty list.")
        return []

    if total <= k:
        print(f"Warning: Only {len(documents)} documents available, returning all.")
        return documents

    query_vec = model.encode(query, convert_to_numpy=True)
    query_vec = query_vec / np.linalg.norm(query_vec)

    # Get candidate indices from the index (k * 4 or less if not enough documents)
    pool_size = min(k * 4, total)
    _, neighbours = index.search(np.expand_dims(query_vec, axis=0), pool_size)
    pool = list(neighbours[0])

    chosen = []
    # Seed the selection with the best search hit.
    if k > 0 and pool:
        chosen.append(pool.pop(0))

    while len(chosen) < k and pool:
        best_pos = None
        best_score = None
        for pos, cand in enumerate(pool):
            relevance = cosine_similarity([query_vec], [embeddings[cand]])[0][0]
            redundancy = max([
                cosine_similarity([embeddings[cand]], [embeddings[prev]])[0][0]
                for prev in chosen
            ])
            score = lambda_mult * relevance - (1 - lambda_mult) * redundancy
            # Strict '>' keeps the earliest candidate on ties.
            if best_pos is None or score > best_score:
                best_pos, best_score = pos, score
        chosen.append(pool.pop(best_pos))

    return [documents[i] for i in chosen]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# --- Summarization ---
|
| 72 |
+
def summarize_docs(word, timestamp, docs, llm, k):
    """Summarize the key themes of documents that contain ``word``.

    Up to ``k`` documents are sampled with MMR (query = ``word``) and passed
    to the LLM as bullet-point context.

    Args:
        word: The word the documents were retrieved for.
        timestamp: Time label inserted into the prompt.
        docs: List of dicts with a ``'text'`` key.
        llm: LangChain-compatible chat model.
        k: Number of documents to sample.

    Returns:
        Tuple ``(summary, mmr_docs)``. Every path returns this same 2-tuple:
        with no documents or on failure, ``summary`` is an explanatory
        message and ``mmr_docs`` is an empty list.
    """
    if not docs:
        # Same arity as the success path so callers can always unpack two values.
        return "No documents available for this word at this time.", []

    try:
        model, index, embeddings, documents = build_mmr_index(docs)
        mmr_docs = get_mmr_sample(model, index, embeddings, documents, query=word, k=k)

        context_texts = "\n".join(f"- {doc.page_content}" for doc in mmr_docs)

        prompt_template = ChatPromptTemplate.from_template(
            "Given the following documents from {timestamp} containing the word '{word}', "
            "identify the key themes or distinct discussion points that were prevalent during that time. "
            "Do NOT describe each bullet in detail. Be concise. Each bullet should be a short phrase or sentence "
            "capturing a unique, non-overlapping theme. Avoid any elaboration, examples, or justification.\n\n"
            "Return no more than 5–7 bullets.\n\n"
            "{context_texts}\n\nSummary:"
        )

        chain = prompt_template | llm
        summary = chain.invoke({
            "word": word,
            "timestamp": timestamp,
            "context_texts": context_texts
        }).content.strip()

        return summary, mmr_docs

    except Exception as e:
        # Best-effort: report the failure in the summary text, keeping the
        # same 2-tuple shape as the success path.
        return f"[Error summarizing: {e}]", []
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def summarize_multiword_docs(words, timestamp, docs, llm, k):
    """Summarize the key themes of documents that mention all of ``words``.

    Up to ``k`` documents are sampled with MMR, using the space-joined words
    as the query, and passed to the LLM as bullet-point context.

    Args:
        words: List of words the documents were retrieved for.
        timestamp: Time label inserted into the prompt.
        docs: List of dicts with a ``'text'`` key.
        llm: LangChain-compatible chat model.
        k: Number of documents to sample.

    Returns:
        Tuple ``(summary, mmr_docs)``. With no documents or on failure,
        ``summary`` is an explanatory message and ``mmr_docs`` is empty.
    """
    if not docs:
        return "No common documents available for these words at this time.", []

    try:
        encoder, flat_index, vectors, passages = build_mmr_index(docs)
        joined_query = " ".join(words)
        sampled = get_mmr_sample(encoder, flat_index, vectors, passages, query=joined_query, k=k)

        bullets = [f"- {passage.page_content}" for passage in sampled]
        context_texts = "\n".join(bullets)

        template = ChatPromptTemplate.from_template(
            "Given the following documents from {timestamp} that all mention the words: '{word_list}', "
            "identify the key themes or distinct discussion points that were prevalent during that time. "
            "Do NOT describe each bullet in detail. Be concise. Each bullet should be a short phrase or sentence "
            "capturing a unique, non-overlapping theme. Avoid any elaboration, examples, or justification.\n\n"
            "Return no more than 5–7 bullets.\n\n"
            "{context_texts}\n\n"
            "Concise Thematic Summary:"
        )

        pipeline = template | llm
        prompt_values = {
            "word_list": ", ".join(words),
            "timestamp": timestamp,
            "context_texts": context_texts
        }
        reply = pipeline.invoke(prompt_values)
        return reply.content.strip(), sampled

    except Exception as e:
        return f"[Error summarizing: {e}]", []
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# --- Follow-up Question Handler (Improved) ---
|
| 139 |
+
def ask_multiturn_followup(history: list, question: str, llm, context_texts: str) -> str:
    """
    Answer a follow-up question grounded in a supplied set of documents.

    The prior conversation is replayed into a ``ConversationBufferMemory``;
    the new input combines a grounding instruction, the documents and the
    question, and is run through a ``ConversationChain``.

    Args:
        history (list): Conversation turns as dicts with "role" and "content"
            keys (e.g., [{"role": "user", "content": "..."}]). Turns whose
            role is neither "user" nor "assistant" are ignored.
        question (str): The user's new follow-up question.
        llm: The initialized language model instance.
        context_texts (str): A single string containing the documents used
            as context.

    Returns:
        str: The stripped model response, or a fixed error message if
        anything raises (the exception is printed).
    """
    try:
        # Rebuild the conversation memory from the supplied history.
        memory = ConversationBufferMemory(return_messages=True)
        for turn in history:
            role = turn["role"]
            if role == "user":
                memory.chat_memory.add_user_message(turn["content"])
            elif role == "assistant":
                memory.chat_memory.add_ai_message(turn["content"])

        # Instruction that restricts answers to the supplied documents.
        system_instruction = (
            "You are an assistant answering questions strictly based on the provided sample documents below. "
            "Your memory contains the previous turns of this conversation. "
            "If the answer is not clearly available in the text, respond with: "
            "'The information is not available in the documents provided.'\n\n"
        )

        # The chain supplies the memory itself, so the input only needs the
        # instruction, the documents and the current question.
        full_prompt = f"{system_instruction}--- DOCUMENTS ---\n{context_texts.strip()}\n\n--- QUESTION ---\n{question}"

        chain = ConversationChain(llm=llm, memory=memory, verbose=False)
        return chain.predict(input=full_prompt).strip()

    except Exception as e:
        print(f"[ERROR] in ask_multiturn_followup: {e}")
        return "[Error during multi-turn follow-up. Please check the logs.]"
|
backend/llm_utils/token_utils.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Literal
|
| 2 |
+
import tiktoken
|
| 3 |
+
import anthropic
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
# Gemini requires the Vertex AI SDK
|
| 7 |
+
try:
|
| 8 |
+
from vertexai.preview import tokenization as vertex_tokenization
|
| 9 |
+
except ImportError:
|
| 10 |
+
vertex_tokenization = None
|
| 11 |
+
|
| 12 |
+
# Mistral requires the SentencePiece tokenizer
|
| 13 |
+
try:
|
| 14 |
+
import sentencepiece as spm
|
| 15 |
+
except ImportError:
|
| 16 |
+
spm = None
|
| 17 |
+
|
| 18 |
+
# ---------------------------
|
| 19 |
+
# Individual Token Counters
|
| 20 |
+
# ---------------------------
|
| 21 |
+
|
| 22 |
+
def count_tokens_openai(text: str, model_name: str) -> int:
    """Count the tokens of ``text`` with tiktoken.

    Uses the encoding tiktoken associates with ``model_name``; when that
    lookup raises ``KeyError``, the ``cl100k_base`` encoding is used instead.
    """
    try:
        codec = tiktoken.encoding_for_model(model_name)
    except KeyError:
        codec = tiktoken.get_encoding("cl100k_base")  # fallback
    token_ids = codec.encode(text)
    return len(token_ids)
|
| 28 |
+
|
| 29 |
+
def count_tokens_anthropic(text: str, model_name: str) -> int:
    """Count input tokens for ``text`` via the Anthropic token-counting API.

    The text is sent as a single user message. A default ``Anthropic()``
    client is created on every call.

    Args:
        text: Text to count.
        model_name: Anthropic model name.

    Returns:
        The ``input_tokens`` value reported by the API.

    Raises:
        RuntimeError: If the client cannot be created or the request fails.
    """
    try:
        client = anthropic.Anthropic()
        response = client.messages.count_tokens(
            model=model_name,
            messages=[{"role": "user", "content": text}]
        )
        # The SDK returns a response object, not a dict: the count is the
        # ``input_tokens`` attribute, and subscripting it raised TypeError.
        return response.input_tokens
    except Exception as e:
        raise RuntimeError(f"Anthropic token counting failed: {e}") from e
|
| 39 |
+
|
| 40 |
+
def count_tokens_gemini(text: str, model_name: str) -> int:
    """Count the tokens of ``text`` with the Vertex AI local tokenizer.

    The tokenizer is always the one for "gemini-1.5-flash-002";
    ``model_name`` is accepted but not used.

    Raises:
        ImportError: If the Vertex AI tokenization module is unavailable.
        RuntimeError: If loading the tokenizer or counting fails.
    """
    if vertex_tokenization is None:
        raise ImportError("Please install vertexai: pip install google-cloud-aiplatform[tokenization]")
    try:
        counter = vertex_tokenization.get_tokenizer_for_model("gemini-1.5-flash-002")
        return counter.count_tokens(text).total_tokens
    except Exception as e:
        raise RuntimeError(f"Gemini token counting failed: {e}")
|
| 49 |
+
|
| 50 |
+
def count_tokens_mistral(text: str) -> int:
    """Count the tokens in ``text`` with a SentencePiece tokenizer model.

    The model is loaded from "mistral_tokenizer.model" on every call.

    Args:
        text: Text to tokenize.

    Returns:
        Number of SentencePiece pieces produced for ``text``.

    Raises:
        ImportError: If ``sentencepiece`` could not be imported.
        RuntimeError: If loading the model or encoding fails.
    """
    if spm is None:
        raise ImportError("Please install sentencepiece: pip install sentencepiece")
    try:
        processor = spm.SentencePieceProcessor()
        # IMPORTANT: You must provide the correct path to the tokenizer model file
        processor.load("mistral_tokenizer.model")
        pieces = processor.encode(text, out_type=str)
    except Exception as e:
        raise RuntimeError(f"Mistral token counting failed: {e}")
    return len(pieces)
|
| 61 |
+
|
| 62 |
+
# ---------------------------
|
| 63 |
+
# Unified Token Counter
|
| 64 |
+
# ---------------------------
|
| 65 |
+
|
| 66 |
+
def count_tokens(text: str, model_name: str, provider: Literal["OpenAI", "Anthropic", "Gemini", "Mistral"]) -> int:
    """Count the tokens in ``text`` with the counter that matches ``provider``.

    Args:
        text: Text to tokenize.
        model_name: Model name forwarded to the provider-specific counter
            (the Mistral counter does not take it).
        provider: One of "OpenAI", "Anthropic", "Gemini", "Mistral"
            (matched exactly, case-sensitive).

    Returns:
        Token count reported by the provider-specific counter.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider == "OpenAI":
        return count_tokens_openai(text, model_name)
    if provider == "Anthropic":
        return count_tokens_anthropic(text, model_name)
    if provider == "Gemini":
        return count_tokens_gemini(text, model_name)
    if provider == "Mistral":
        return count_tokens_mistral(text)
    raise ValueError(f"Unsupported provider: {provider}")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_token_limit_for_model(model_name, provider):
    """Return the context-window size (in tokens) for a model.

    The provider is matched case-insensitively, so both "OpenAI" (the
    spelling used by ``count_tokens``) and "openai" are accepted. The model
    is matched by substring, so versioned names such as
    "claude-3-opus-20240229" resolve to their family limit.

    Args:
        model_name: Model identifier; non-string values (e.g. ``None``)
            match nothing.
        provider: Provider name; non-string values match nothing.

    Returns:
        int: Token limit for the first matching model family, or 8000 when
        the provider or model is not recognised.
    """
    # Example values; update as needed for your providers
    limits_by_provider = {
        "openai": (
            ("gpt-4.1-nano", 1047576),
            ("gpt-4o-mini", 128000),
        ),
        "anthropic": (
            ("claude-3-opus", 200000),
            ("claude-3-sonnet", 200000),
        ),
        "gemini": (
            ("gemini-2.0-flash-lite", 1048576),
            ("gemini-1.5-flash", 1048576),
        ),
        "mistral": (
            ("mistral-small", 32000),
            ("mistral-medium", 32000),
        ),
    }

    provider_key = provider.lower() if isinstance(provider, str) else ""
    model_key = model_name if isinstance(model_name, str) else ""

    for family, limit in limits_by_provider.get(provider_key, ()):
        if family in model_key:
            return limit
    return 8000  # default fallback
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def estimate_avg_tokens_per_doc(
    docs: List[str],
    model_name: str,
    provider: Literal["OpenAI", "Anthropic", "Gemini", "Mistral"]
) -> float:
    """
    Compute the mean token count over ``docs`` for the given model.

    Every document is tokenized with ``count_tokens``.

    Args:
        docs (List[str]): Documents to measure.
        model_name (str): Model name forwarded to ``count_tokens``.
        provider (Literal): LLM provider forwarded to ``count_tokens``.

    Returns:
        float: Mean tokens per document, or 0.0 when ``docs`` is empty.
    """
    if not docs:
        return 0.0

    total_tokens = 0
    for document in docs:
        total_tokens += count_tokens(document, model_name, provider)
    return total_tokens / len(docs)
|
| 124 |
+
|
| 125 |
+
def estimate_max_k(
    docs: List[str],
    model_name: str,
    provider: Literal["OpenAI", "Anthropic", "Gemini", "Mistral"],
    margin_ratio: float = 0.1,
) -> int:
    """
    Estimate how many documents fit in the model's context window.

    The model limit is reduced by ``margin_ratio`` and the remainder is
    divided by the average token count per document.

    Args:
        docs (List[str]): Candidate documents.
        model_name (str): Model name used for the limit and for tokenizing.
        provider (Literal): LLM provider.
        margin_ratio (float): Fraction of the limit held back as a margin.

    Returns:
        int: Estimated K, never more than ``len(docs)``; 0 when ``docs`` is
        empty or the average document has no tokens.
    """
    if not docs:
        return 0

    context_limit = get_token_limit_for_model(model_name, provider)
    usable_tokens = context_limit - int(context_limit * margin_ratio)

    tokens_per_doc = estimate_avg_tokens_per_doc(docs, model_name, provider)
    if tokens_per_doc == 0:
        return 0

    docs_that_fit = int(usable_tokens // tokens_per_doc)
    return min(len(docs), docs_that_fit)
|
| 149 |
+
|
| 150 |
+
def estimate_max_k_fast(docs, margin_ratio=0.1, max_tokens=8000, model_name="gpt-3.5-turbo"):
    """Quickly estimate how many documents fit into ``max_tokens``.

    Only the first 20 documents are tokenized (with tiktoken) to estimate the
    average document length.

    Args:
        docs: Sequence of document strings.
        margin_ratio: Fraction of ``max_tokens`` held back as a margin.
        max_tokens: Token budget of the target context window.
        model_name: Model name used to select the tiktoken encoding; unknown
            names fall back to "cl100k_base", as in ``count_tokens_openai``.

    Returns:
        int: Estimated K, never more than ``len(docs)``; 0 when ``docs`` is
        empty or the sampled documents contain no tokens.
    """
    # Nothing to fit; also avoids dividing by a sample size of zero.
    if not docs:
        return 0

    try:
        enc = tiktoken.encoding_for_model(model_name)
    except KeyError:
        enc = tiktoken.get_encoding("cl100k_base")  # fallback

    sample = docs[:20]
    avg_len = sum(len(enc.encode(doc)) for doc in sample) / len(sample)
    # All sampled documents are empty: mirror estimate_max_k and return 0
    # instead of dividing by zero.
    if avg_len == 0:
        return 0

    margin = int(max_tokens * margin_ratio)
    available = max_tokens - margin
    return min(len(docs), int(available // avg_len))
|
| 156 |
+
|
| 157 |
+
def estimate_k_max_from_word_stats(
    avg_words_per_doc: float,
    margin_ratio: float = 0.1,
    avg_tokens_per_word: float = 1.3,
    model_name=None,
    provider=None
) -> int:
    """Estimate how many documents fit in the context window from word statistics.

    Tokens per document are approximated as
    ``avg_words_per_doc * avg_tokens_per_word``; no tokenizer is invoked.

    Args:
        avg_words_per_doc: Average number of words per document.
        margin_ratio: Fraction of the model limit held back as a margin.
        avg_tokens_per_word: Assumed tokens-per-word ratio.
        model_name: Model name forwarded to ``get_token_limit_for_model``.
        provider: Provider forwarded to ``get_token_limit_for_model``.

    Returns:
        int: Estimated number of documents, or 0 when the estimated tokens
        per document is not positive.
    """
    est_tokens_per_doc = avg_words_per_doc * avg_tokens_per_word
    # A non-positive size estimate would divide by zero (or give a negative
    # K); report 0, as estimate_max_k does for zero-token documents.
    if est_tokens_per_doc <= 0:
        return 0

    model_token_limit = get_token_limit_for_model(model_name, provider)
    effective_limit = int(model_token_limit * (1 - margin_ratio))
    return int(effective_limit // est_tokens_per_doc)
|
backend/models/CFDTM/CFDTM.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
from .ETC import ETC
|
| 7 |
+
from .UWE import UWE
|
| 8 |
+
from .Encoder import MLPEncoder
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class CFDTM(nn.Module):
    '''
    Modeling Dynamic Topics in Chain-Free Fashion by Evolution-Tracking Contrastive Learning and Unassociated Word Exclusion. ACL 2024 Findings

    Xiaobao Wu, Xinshuai Dong, Liangming Pan, Thong Nguyen, Anh Tuan Luu.

    The training loss returned by ``forward`` is the sum of the KL term of the
    document encoder, the reconstruction loss, the ETC loss over topic
    embeddings and the UWE loss.
    '''

    def __init__(self,
                 vocab_size,
                 train_time_wordfreq,
                 num_times,
                 pretrained_WE=None,
                 num_topics=50,
                 en_units=100,
                 temperature=0.1,
                 beta_temp=1.0,
                 weight_neg=1.0e+7,
                 weight_pos=1.0e+1,
                 weight_UWE=1.0e+3,
                 neg_topk=15,
                 dropout=0.,
                 embed_size=200
                 ):
        """Build the encoder, the prior, the embeddings and the two auxiliary losses.

        Args:
            vocab_size: Vocabulary size (input width of the encoder).
            train_time_wordfreq: Per-time word frequencies, handed to UWE in
                ``forward`` (its first dimension must equal ``num_times``).
            num_times: Number of time slices.
            pretrained_WE: Optional NumPy word-embedding matrix; when omitted,
                embeddings are drawn randomly with width ``embed_size``.
            num_topics: Number of topics per time slice.
            en_units: Hidden width of the MLP encoder.
            temperature: Temperature shared by the ETC and UWE losses.
            beta_temp: Temperature of the softmax that produces beta.
            weight_neg, weight_pos: Weights of the two ETC terms.
            weight_UWE: Weight of the UWE loss.
            neg_topk: Number of top words per topic inspected by UWE.
            dropout: Dropout rate of the encoder.
            embed_size: Embedding width used when ``pretrained_WE`` is None.
        """
        super().__init__()

        self.num_topics = num_topics
        self.beta_temp = beta_temp
        self.train_time_wordfreq = train_time_wordfreq
        self.encoder = MLPEncoder(vocab_size, num_topics, en_units, dropout)

        # Fixed (non-trainable) prior mean and variance derived from `a`.
        self.a = 1 * np.ones((1, num_topics)).astype(np.float32)
        log_a = np.log(self.a)
        inv_a = 1.0 / self.a
        prior_mean = (log_a.T - np.mean(log_a, 1)).T
        prior_var = ((inv_a * (1 - (2.0 / num_topics))).T + (1.0 / (num_topics * num_topics)) * np.sum(inv_a, 1)).T
        self.mu2 = nn.Parameter(torch.as_tensor(prior_mean), requires_grad=False)
        self.var2 = nn.Parameter(torch.as_tensor(prior_var), requires_grad=False)

        self.decoder_bn = nn.BatchNorm1d(vocab_size, affine=False)

        if pretrained_WE is None:
            random_WE = nn.init.trunc_normal_(torch.empty(vocab_size, embed_size), std=0.1)
            self.word_embeddings = nn.Parameter(F.normalize(random_WE))
        else:
            self.word_embeddings = nn.Parameter(torch.from_numpy(pretrained_WE).float())

        # topic_embeddings: TxKxD (one initial K x D matrix repeated for every time slice)
        initial_topics = nn.init.xavier_normal_(torch.zeros(num_topics, self.word_embeddings.shape[1]))
        self.topic_embeddings = nn.Parameter(initial_topics.repeat(num_times, 1, 1))

        self.ETC = ETC(num_times, temperature, weight_neg, weight_pos)
        self.UWE = UWE(self.ETC, num_times, temperature, weight_UWE, neg_topk)

    def get_beta(self):
        """Return beta: a softmax (over dim 1) of negative topic-word distances."""
        unit_topics = F.normalize(self.topic_embeddings, dim=-1)
        unit_words = F.normalize(self.word_embeddings, dim=-1)
        dist = self.pairwise_euclidean_dist(unit_topics, unit_words)
        return F.softmax(-dist / self.beta_temp, dim=1)

    def pairwise_euclidean_dist(self, x, y):
        """Return squared Euclidean distances between rows of ``x`` and rows of ``y``.

        Uses ||x||^2 + ||y||^2 - 2 x.y; ``y`` must be 2-D because it is
        transposed with ``.t()``.
        """
        x_sq = torch.sum(x ** 2, axis=-1, keepdim=True)
        y_sq = torch.sum(y ** 2, axis=-1)
        cross = torch.matmul(x, y.t())
        return x_sq + y_sq - 2 * cross

    def get_theta(self, x, times=None):
        """Encode ``x``; return ``(theta, mu, logvar)`` in training mode, else ``theta``.

        ``times`` is accepted but not used.
        """
        theta, mu, logvar = self.encoder(x)
        return (theta, mu, logvar) if self.training else theta

    def get_KL(self, mu, logvar):
        """Return the batch-mean KL divergence between N(mu, exp(logvar)) and the fixed prior."""
        var_ratio = logvar.exp() / self.var2
        mean_gap = mu - self.mu2
        mean_term = mean_gap * mean_gap / self.var2
        logvar_gap = self.var2.log() - logvar
        kld = 0.5 * ((var_ratio + mean_term + logvar_gap).sum(axis=1) - self.num_topics)
        return kld.mean()

    def get_NLL(self, theta, beta, x, recon_x=None):
        """Return the per-document reconstruction loss -sum(x * log(recon_x)).

        ``recon_x`` is computed with ``decode`` when it is not supplied.
        """
        if recon_x is None:
            recon_x = self.decode(theta, beta)
        return -(x * recon_x.log()).sum(axis=1)

    def decode(self, theta, beta):
        """Mix each document's theta with its beta, batch-normalise and softmax over words."""
        mixed = torch.bmm(theta.unsqueeze(1), beta).squeeze(1)
        return F.softmax(self.decoder_bn(mixed), dim=-1)

    def forward(self, x, times):
        """Compute the total training loss for a batch.

        Args:
            x: Batch of document vectors fed to the encoder.
            times: Per-document time indices used to select rows of beta.

        Returns:
            dict: ``{'loss': total_loss}``.
        """
        theta, mu, logvar = self.get_theta(x)
        kl_theta = self.get_KL(mu, logvar)

        beta = self.get_beta()
        time_index_beta = beta[times]
        recon_x = self.decode(theta, time_index_beta)
        nll = self.get_NLL(theta, time_index_beta, x, recon_x).mean()

        loss_ETC = self.ETC(self.topic_embeddings)
        loss_UWE = self.UWE(self.train_time_wordfreq, beta, self.topic_embeddings, self.word_embeddings)

        total_loss = kl_theta + nll + loss_ETC + loss_UWE
        return {
            'loss': total_loss,
        }
|
backend/models/CFDTM/ETC.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ETC(nn.Module):
    """Evolution-tracking contrastive loss over per-time topic embeddings.

    Within each time slice topics are pushed apart (negative term); across
    consecutive slices each topic is pulled toward the same-index topic of the
    previous slice (positive term).
    """

    def __init__(self, num_times, temperature, weight_neg, weight_pos):
        """Store the number of time slices, the temperature and both loss weights."""
        super().__init__()
        self.num_times = num_times
        self.weight_neg = weight_neg
        self.weight_pos = weight_pos
        self.temperature = temperature

    def forward(self, topic_embeddings):
        """Return the weighted sum of the negative and positive terms.

        Args:
            topic_embeddings: Indexable by time; ``topic_embeddings[t]`` is the
                2-D topic-embedding matrix of slice ``t``.

        Returns:
            The averaged, weighted loss. With a single time slice there are no
            consecutive pairs, so the positive term is zero.
        """
        loss_neg = 0.
        loss_pos = 0.

        for t in range(self.num_times):
            loss_neg += self.compute_loss(topic_embeddings[t], topic_embeddings[t], self.temperature, self_contrast=True)

        for t in range(1, self.num_times):
            # The previous slice is detached: it is a fixed target here.
            loss_pos += self.compute_loss(topic_embeddings[t], topic_embeddings[t - 1].detach(), self.temperature, self_contrast=False, only_pos=True)

        loss_neg *= (self.weight_neg / self.num_times)
        # Average over the num_times - 1 consecutive pairs; skipped for a
        # single slice, where the divisor would be zero.
        if self.num_times > 1:
            loss_pos *= (self.weight_pos / (self.num_times - 1))

        return loss_neg + loss_pos

    def compute_loss(self, anchor_feature, contrast_feature, temperature, self_contrast=False, only_pos=False, all_neg=False):
        """Compute one contrastive term between anchors and contrast features.

        Exactly one mode applies, checked in this order:

        * ``self_contrast`` is not False: off-diagonal pairs are negatives.
        * ``only_pos`` is not False: only diagonal (same-index) pairs are
          used, as positives; needs a square similarity matrix.
        * ``all_neg`` is True: every contrast feature is a negative.

        Args:
            anchor_feature: 2-D tensor, one anchor per row.
            contrast_feature: 2-D tensor, one contrast feature per row.
            temperature: Divisor applied to the cosine similarities.
            self_contrast, only_pos, all_neg: Mode flags (see above).

        Returns:
            Scalar tensor with the loss term.

        Raises:
            ValueError: If none of the three modes is selected.
        """
        # KxK
        anchor_dot_contrast = torch.div(
            torch.matmul(F.normalize(anchor_feature, dim=1), F.normalize(contrast_feature, dim=1).T),
            temperature
        )

        # Subtract the row maximum for numerical stability of exp().
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        if self_contrast is False and only_pos is False:
            if all_neg is not True:
                raise ValueError("compute_loss requires one of self_contrast, only_pos or all_neg to be set.")
            sum_exp_logits = torch.exp(logits).sum(1)
            log_prob = -torch.log(sum_exp_logits + 1e-12)
            return -log_prob.sum() / (logits.shape[0] * logits.shape[1])

        pos_mask = torch.eye(logits.shape[0], device=logits.device)

        if self_contrast is False:
            # only pos
            return -(logits * pos_mask).sum() / pos_mask.sum()

        # self contrast: push away from each other in the same time slice.
        exp_logits = torch.exp(logits) * (1 - pos_mask)
        sum_exp_logits = exp_logits.sum(1)
        log_prob = -torch.log(sum_exp_logits + 1e-12)
        return -log_prob.sum() / (1 - pos_mask).sum()
|
backend/models/CFDTM/Encoder.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MLPEncoder(nn.Module):
    """Two-layer softplus MLP that maps a document vector to topic proportions.

    Produces a mean and log-variance (each batch-normalised), samples with the
    reparameterisation trick in training mode and applies a softmax.
    """

    def __init__(self, vocab_size, num_topic, hidden_dim, dropout):
        """Create the layers.

        Args:
            vocab_size: Input width.
            num_topic: Output width (number of topics).
            hidden_dim: Width of both hidden layers.
            dropout: Dropout probability used after the hidden layers and on theta.
        """
        super().__init__()

        self.fc11 = nn.Linear(vocab_size, hidden_dim)
        self.fc12 = nn.Linear(hidden_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, num_topic)
        self.fc22 = nn.Linear(hidden_dim, num_topic)

        self.fc1_drop = nn.Dropout(dropout)
        self.z_drop = nn.Dropout(dropout)

        # The batch-norm scale is frozen; only its bias is trainable.
        self.mean_bn = nn.BatchNorm1d(num_topic, affine=True)
        self.mean_bn.weight.requires_grad = False
        self.logvar_bn = nn.BatchNorm1d(num_topic, affine=True)
        self.logvar_bn.weight.requires_grad = False

    def reparameterize(self, mu, logvar):
        """Sample mu + eps * std in training mode; return ``mu`` otherwise."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + (noise * std)

    def forward(self, x):
        """Encode ``x`` and return ``(theta, mu, logvar)``.

        ``theta`` is the softmax (over dim 1) of the sampled latent, followed
        by dropout.
        """
        hidden = F.softplus(self.fc11(x))
        hidden = self.fc1_drop(F.softplus(self.fc12(hidden)))

        mu = self.mean_bn(self.fc21(hidden))
        logvar = self.logvar_bn(self.fc22(hidden))

        latent = self.reparameterize(mu, logvar)
        theta = self.z_drop(F.softmax(latent, dim=1))
        return theta, mu, logvar
|
backend/models/CFDTM/UWE.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class UWE(nn.Module):
    """Unassociated-word-exclusion loss.

    For every time slice, the top words of each topic that have a zero count
    in that slice are treated as negatives for that slice's topic embeddings.
    """

    def __init__(self, ETC, num_times, temperature, weight_UWE, neg_topk):
        """Store the shared ETC module (used for ``compute_loss``) and the hyper-parameters."""
        super().__init__()

        self.ETC = ETC
        self.weight_UWE = weight_UWE
        self.num_times = num_times
        self.temperature = temperature
        self.neg_topk = neg_topk

    def forward(self, time_wordcount, beta, topic_embeddings, word_embeddings):
        """Return the weighted UWE loss averaged over slices that have negatives.

        Args:
            time_wordcount: Per-time word counts; first dimension must equal
                ``num_times``, and row ``t`` is indexed by word id.
            beta: Topic-word tensor handed to ``get_topk_indices``.
            topic_embeddings: Indexable by time; ``topic_embeddings[t]`` is K x D.
            word_embeddings: Word-embedding matrix indexed by word id.

        Returns:
            The loss, or ``0.`` when no slice has any zero-count top word.
        """
        assert(self.num_times == time_wordcount.shape[0])

        topk_indices = self.get_topk_indices(beta)

        loss_UWE = 0.
        cnt_valid_times = 0.
        for t in range(self.num_times):
            # Distinct top-word ids of this slice, filtered on-device to those
            # with a zero count at time t (no host round trip, stable order).
            candidate_idx = torch.unique(topk_indices[t]).to(time_wordcount.device)
            neg_idx = candidate_idx[time_wordcount[t][candidate_idx] == 0]

            if len(neg_idx) == 0:
                continue

            time_neg_WE = word_embeddings[neg_idx]

            # topic_embeddings[t]: K x D
            # word_embeddings[neg_idx]: |V_{neg}| x D
            loss_UWE += self.ETC.compute_loss(topic_embeddings[t], time_neg_WE, temperature=self.temperature, all_neg=True)
            cnt_valid_times += 1

        if cnt_valid_times > 0:
            loss_UWE *= (self.weight_UWE / cnt_valid_times)

        return loss_UWE

    def get_topk_indices(self, beta):
        """Return, per time slice, the flattened ids of each topic's ``neg_topk`` top words."""
        # topk_indices: T x K x neg_topk
        topk_indices = torch.topk(beta, k=self.neg_topk, dim=-1).indices
        # Flattened to T x (K * neg_topk)
        topk_indices = torch.flatten(topk_indices, start_dim=1)
        return topk_indices
|
backend/models/CFDTM/__init__.py
ADDED
|
File without changes
|
backend/models/CFDTM/__pycache__/CFDTM.cpython-39.pyc
ADDED
|
Binary file (4.01 kB). View file
|
|
|
backend/models/CFDTM/__pycache__/ETC.cpython-39.pyc
ADDED
|
Binary file (1.85 kB). View file
|
|
|
backend/models/CFDTM/__pycache__/Encoder.cpython-39.pyc
ADDED
|
Binary file (1.52 kB). View file
|
|
|
backend/models/CFDTM/__pycache__/UWE.cpython-39.pyc
ADDED
|
Binary file (1.56 kB). View file
|
|
|
backend/models/CFDTM/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (158 Bytes). View file
|
|
|
backend/models/DBERTopic_trainer.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from bertopic import BERTopic
|
| 3 |
+
from backend.datasets.utils import _utils
|
| 4 |
+
from backend.datasets.utils.logger import Logger
|
| 5 |
+
|
| 6 |
+
logger = Logger("WARNING")
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DBERTopicTrainer:
    """Train BERTopic on a dataset and expose its topics-over-time as a beta tensor.

    ``nr_bins``, ``global_tuning``, ``evolution_tuning`` and ``datetime_format``
    are accepted by the constructor but are not stored or used.
    """

    def __init__(self,
                 dataset,
                 num_topics=20,
                 num_top_words=15,
                 nr_bins=20,
                 global_tuning=True,
                 evolution_tuning=True,
                 datetime_format=None,
                 verbose=False):
        """Store the dataset's documents and vocabulary and set the log level.

        Args:
            dataset: Object providing ``raw_documents`` and ``vocab``.
            num_topics: Passed to BERTopic as ``nr_topics``.
            num_top_words: Default number of top words for ``get_top_words``.
            nr_bins, global_tuning, evolution_tuning, datetime_format: Unused.
            verbose: Enables DEBUG logging and BERTopic verbosity.
        """
        self.dataset = dataset
        self.docs = dataset.raw_documents
        self.num_topics = num_topics
        self.vocab = dataset.vocab
        self.num_top_words = num_top_words
        self.verbose = verbose

        if verbose:
            logger.set_level("DEBUG")
        else:
            logger.set_level("WARNING")

    def train(self, timestamps, datetime_format='%Y'):
        """Fit BERTopic on the documents and compute its topics over time.

        After training, ``self.vocab`` is replaced by the fitted vectorizer's
        feature names, and ``V``, ``K`` and ``T`` hold the vocabulary size,
        the number of distinct topics and the number of distinct timestamps.

        Args:
            timestamps: One timestamp per document; the number of distinct
                values is used as ``nr_bins``.
            datetime_format: Format string forwarded to ``topics_over_time``.
        """
        logger.info("Fitting BERTopic...")
        self.model = BERTopic(nr_topics=self.num_topics, verbose=self.verbose)
        self.topics, _ = self.model.fit_transform(self.docs)

        logger.info("Running topics_over_time...")
        self.topics_over_time_df = self.model.topics_over_time(
            docs=self.docs,
            timestamps=timestamps,
            nr_bins=len(set(timestamps)),
            datetime_format=datetime_format
        )

        self.unique_timestamps = sorted(self.topics_over_time_df["Timestamp"].unique())
        self.unique_topics = sorted(self.topics_over_time_df["Topic"].unique())
        self.vocab = self.model.vectorizer_model.get_feature_names_out()
        self.V = len(self.vocab)
        self.K = len(self.unique_topics)
        self.T = len(self.unique_timestamps)

    def get_beta(self):
        """Build a T x K x V tensor from the per-timestamp topic words.

        Each word listed for a (timestamp, topic) row adds 1.0 to its
        vocabulary slot; every topic row is then normalised to sum to
        (approximately) one. Rows without any known word stay all-zero.

        Returns:
            numpy.ndarray: Array of shape ``(T, K, V)``.
        """
        logger.info("Generating β matrix...")

        beta = np.zeros((self.T, self.K, self.V))
        topic_to_index = {topic: idx for idx, topic in enumerate(self.unique_topics)}

        # Word -> first index, built once so each lookup is O(1) instead of
        # scanning the whole vocabulary array for every word.
        word_to_index = {}
        for idx, vocab_word in enumerate(self.vocab):
            word_to_index.setdefault(vocab_word, idx)

        # Extract topic representations at each time
        for t_idx, timestamp in enumerate(self.unique_timestamps):
            selection = self.topics_over_time_df[self.topics_over_time_df["Timestamp"] == timestamp]
            for topic, joined_words in zip(selection["Topic"], selection["Words"]):
                if topic not in topic_to_index:
                    continue
                k = topic_to_index[topic]
                for word in joined_words.split(", "):
                    v = word_to_index.get(word)
                    if v is not None:
                        beta[t_idx, k, v] += 1.0

        # Normalize each β_tk to be a probability distribution
        # (the small constant avoids division by zero for empty rows)
        beta = beta / (beta.sum(axis=2, keepdims=True) + 1e-10)
        return beta

    def get_top_words(self, num_top_words=None):
        """Return the top words of every topic, one entry per time slice.

        Args:
            num_top_words: Words per topic; defaults to ``self.num_top_words``.

        Returns:
            list: For each time slice, the result of ``_utils.get_top_words``.
        """
        if num_top_words is None:
            num_top_words = self.num_top_words
        beta = self.get_beta()
        return [
            _utils.get_top_words(beta[time_idx], self.vocab, num_top_words, self.verbose)
            for time_idx in range(beta.shape[0])
        ]

    def get_theta(self):
        """Not implemented for BERTopic; logs a warning and returns None."""
        # Not applicable for BERTopic; can return topic assignments or soft topic distributions if required
        logger.warning("get_theta is not implemented for BERTopic.")
        return None

    def export_theta(self):
        """Not implemented for BERTopic; logs a warning and returns ``(None, None)``."""
        logger.warning("export_theta is not implemented for BERTopic.")
        return None, None
|
backend/models/DETM.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class DETM(nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
The Dynamic Embedded Topic Model. 2019
|
| 10 |
+
|
| 11 |
+
Adji B. Dieng, Francisco J. R. Ruiz, David M. Blei
|
| 12 |
+
"""
|
| 13 |
+
def __init__(self, vocab_size, num_times, train_size, train_time_wordfreq,
|
| 14 |
+
num_topics=50, train_WE=True, pretrained_WE=None, en_units=800,
|
| 15 |
+
eta_hidden_size=200, rho_size=300, enc_drop=0.0, eta_nlayers=3,
|
| 16 |
+
eta_dropout=0.0, delta=0.005, theta_act='relu', device='cpu'):
|
| 17 |
+
super().__init__()
|
| 18 |
+
|
| 19 |
+
## define hyperparameters
|
| 20 |
+
self.num_topics = num_topics
|
| 21 |
+
self.num_times = num_times
|
| 22 |
+
self.vocab_size = vocab_size
|
| 23 |
+
self.eta_hidden_size = eta_hidden_size
|
| 24 |
+
self.rho_size = rho_size
|
| 25 |
+
self.enc_drop = enc_drop
|
| 26 |
+
self.eta_nlayers = eta_nlayers
|
| 27 |
+
self.t_drop = nn.Dropout(enc_drop)
|
| 28 |
+
self.eta_dropout = eta_dropout
|
| 29 |
+
self.delta = delta
|
| 30 |
+
self.train_WE = train_WE
|
| 31 |
+
self.train_size = train_size
|
| 32 |
+
self.rnn_inp = train_time_wordfreq
|
| 33 |
+
self.device = device
|
| 34 |
+
|
| 35 |
+
self.theta_act = self.get_activation(theta_act)
|
| 36 |
+
|
| 37 |
+
## define the word embedding matrix \rho
|
| 38 |
+
if self.train_WE:
|
| 39 |
+
self.rho = nn.Linear(self.rho_size, self.vocab_size, bias=False)
|
| 40 |
+
else:
|
| 41 |
+
rho = nn.Embedding(pretrained_WE.size())
|
| 42 |
+
rho.weight.data = torch.from_numpy(pretrained_WE)
|
| 43 |
+
self.rho = rho.weight.data.clone().float().to(self.device)
|
| 44 |
+
|
| 45 |
+
## define the variational parameters for the topic embeddings over time (alpha) ... alpha is K x T x L
|
| 46 |
+
self.mu_q_alpha = nn.Parameter(torch.randn(self.num_topics, self.num_times, self.rho_size))
|
| 47 |
+
self.logsigma_q_alpha = nn.Parameter(torch.randn(self.num_topics, self.num_times, self.rho_size))
|
| 48 |
+
|
| 49 |
+
## define variational distribution for \theta_{1:D} via amortizartion... theta is K x D
|
| 50 |
+
self.q_theta = nn.Sequential(
|
| 51 |
+
nn.Linear(self.vocab_size + self.num_topics, en_units),
|
| 52 |
+
self.theta_act,
|
| 53 |
+
nn.Linear(en_units, en_units),
|
| 54 |
+
self.theta_act,
|
| 55 |
+
)
|
| 56 |
+
self.mu_q_theta = nn.Linear(en_units, self.num_topics, bias=True)
|
| 57 |
+
self.logsigma_q_theta = nn.Linear(en_units, self.num_topics, bias=True)
|
| 58 |
+
|
| 59 |
+
## define variational distribution for \eta via amortizartion... eta is K x T
|
| 60 |
+
self.q_eta_map = nn.Linear(self.vocab_size, self.eta_hidden_size)
|
| 61 |
+
self.q_eta = nn.LSTM(self.eta_hidden_size, self.eta_hidden_size, self.eta_nlayers, dropout=self.eta_dropout)
|
| 62 |
+
self.mu_q_eta = nn.Linear(self.eta_hidden_size + self.num_topics, self.num_topics, bias=True)
|
| 63 |
+
self.logsigma_q_eta = nn.Linear(self.eta_hidden_size + self.num_topics, self.num_topics, bias=True)
|
| 64 |
+
|
| 65 |
+
self.decoder_bn = nn.BatchNorm1d(vocab_size)
|
| 66 |
+
self.decoder_bn.weight.requires_grad = False
|
| 67 |
+
|
| 68 |
+
def get_activation(self, act):
    """Return a new activation module for the name ``act``.

    Known names: 'tanh', 'relu', 'softplus', 'rrelu', 'leakyrelu', 'elu',
    'selu', 'glu'. Any other value prints a notice and yields ``nn.Tanh``.
    """
    factories = {
        'tanh': nn.Tanh,
        'relu': nn.ReLU,
        'softplus': nn.Softplus,
        'rrelu': nn.RReLU,
        'leakyrelu': nn.LeakyReLU,
        'elu': nn.ELU,
        'selu': nn.SELU,
        'glu': nn.GLU,
    }

    factory = factories.get(act)
    if factory is None:
        print('Defaulting to tanh activations...')
        return nn.Tanh()
    return factory()
|
| 86 |
+
|
| 87 |
+
def reparameterize(self, mu, logvar):
    """Draw a Gaussian sample with the reparameterization trick.

    In training mode returns ``eps * exp(0.5 * logvar) + mu`` with standard
    normal ``eps``; in evaluation mode returns ``mu`` unchanged.
    """
    if not self.training:
        return mu
    std = torch.exp(0.5 * logvar)
    noise = torch.randn_like(std)
    return noise.mul_(std).add_(mu)
|
| 96 |
+
|
| 97 |
+
def get_kl(self, q_mu, q_logsigma, p_mu=None, p_logsigma=None):
    """Return KL( N(q_mu, q_logsigma) || N(p_mu, p_logsigma) ), summed over the last dim.

    The ``logsigma`` arguments are exponentiated directly to obtain
    variances. When either prior argument is missing, the prior is the
    standard normal.
    """
    if p_mu is None or p_logsigma is None:
        return -0.5 * torch.sum(1 + q_logsigma - q_mu.pow(2) - q_logsigma.exp(), dim=-1)

    sigma_q_sq = torch.exp(q_logsigma)
    sigma_p_sq = torch.exp(p_logsigma)
    # 1e-6 guards the division against a vanishing prior variance.
    ratio = ( sigma_q_sq + (q_mu - p_mu)**2 ) / ( sigma_p_sq + 1e-6 )
    per_dim = ratio - 1 + p_logsigma - q_logsigma
    return 0.5 * torch.sum(per_dim, dim=-1)
|
| 109 |
+
|
| 110 |
+
def get_alpha(self): ## mean field
|
| 111 |
+
alphas = torch.zeros(self.num_times, self.num_topics, self.rho_size).to(self.device)
|
| 112 |
+
kl_alpha = []
|
| 113 |
+
|
| 114 |
+
alphas[0] = self.reparameterize(self.mu_q_alpha[:, 0, :], self.logsigma_q_alpha[:, 0, :])
|
| 115 |
+
|
| 116 |
+
# TODO: why logsigma_p_0 is zero?
|
| 117 |
+
p_mu_0 = torch.zeros(self.num_topics, self.rho_size).to(self.device)
|
| 118 |
+
logsigma_p_0 = torch.zeros(self.num_topics, self.rho_size).to(self.device)
|
| 119 |
+
kl_0 = self.get_kl(self.mu_q_alpha[:, 0, :], self.logsigma_q_alpha[:, 0, :], p_mu_0, logsigma_p_0)
|
| 120 |
+
kl_alpha.append(kl_0)
|
| 121 |
+
for t in range(1, self.num_times):
|
| 122 |
+
alphas[t] = self.reparameterize(self.mu_q_alpha[:, t, :], self.logsigma_q_alpha[:, t, :])
|
| 123 |
+
|
| 124 |
+
p_mu_t = alphas[t - 1]
|
| 125 |
+
logsigma_p_t = torch.log(self.delta * torch.ones(self.num_topics, self.rho_size).to(self.device))
|
| 126 |
+
kl_t = self.get_kl(self.mu_q_alpha[:, t, :], self.logsigma_q_alpha[:, t, :], p_mu_t, logsigma_p_t)
|
| 127 |
+
kl_alpha.append(kl_t)
|
| 128 |
+
kl_alpha = torch.stack(kl_alpha).sum()
|
| 129 |
+
return alphas, kl_alpha.sum()
|
| 130 |
+
|
| 131 |
+
def get_eta(self, rnn_inp):  ## structured amortized inference
    """Infer the per-time latent means eta with the LSTM and return their KL.

    ``rnn_inp`` is projected by ``q_eta_map``, given a singleton batch axis
    and run through the LSTM; each step's output is concatenated with the
    previous eta to produce the variational mean and log-variance.

    Args:
        rnn_inp: Per-time input with one row per time slice.
            NOTE(review): presumably (num_times, vocab_size), given that
            ``q_eta_map`` is ``Linear(vocab_size, eta_hidden_size)`` -- confirm.

    Returns:
        tuple: ``(etas, kl_eta)`` where ``etas`` has shape
        (num_times, num_topics) and ``kl_eta`` is a scalar.
    """
    inp = self.q_eta_map(rnn_inp).unsqueeze(1)
    hidden = self.init_hidden()
    output, _ = self.q_eta(inp, hidden)
    # Drop only the singleton batch axis added above. A bare squeeze() would
    # also drop the time axis when there is a single time slice.
    output = output.squeeze(1)

    etas = torch.zeros(self.num_times, self.num_topics).to(self.device)
    kl_eta = []

    # First step: there is no previous eta, so zeros are concatenated.
    inp_0 = torch.cat([output[0], torch.zeros(self.num_topics,).to(self.device)], dim=0)
    mu_0 = self.mu_q_eta(inp_0)
    logsigma_0 = self.logsigma_q_eta(inp_0)
    etas[0] = self.reparameterize(mu_0, logsigma_0)

    p_mu_0 = torch.zeros(self.num_topics,).to(self.device)
    logsigma_p_0 = torch.zeros(self.num_topics,).to(self.device)
    kl_0 = self.get_kl(mu_0, logsigma_0, p_mu_0, logsigma_p_0)
    kl_eta.append(kl_0)

    for t in range(1, self.num_times):
        inp_t = torch.cat([output[t], etas[t-1]], dim=0)
        mu_t = self.mu_q_eta(inp_t)
        logsigma_t = self.logsigma_q_eta(inp_t)
        etas[t] = self.reparameterize(mu_t, logsigma_t)

        # Random-walk prior centred on the previous eta.
        p_mu_t = etas[t-1]
        logsigma_p_t = torch.log(self.delta * torch.ones(self.num_topics,).to(self.device))
        kl_t = self.get_kl(mu_t, logsigma_t, p_mu_t, logsigma_p_t)
        kl_eta.append(kl_t)
    kl_eta = torch.stack(kl_eta).sum()

    return etas, kl_eta
|
| 163 |
+
|
| 164 |
+
def get_theta(self, bows, times, eta=None):  ## amortized inference
    """Returns the topic proportions.

    Args:
        bows: Bag-of-words batch; normalised row-wise before being fed to
            ``self.q_theta``. Presumably ``(batch, vocab)`` -- confirm.
        times: Index used to pick one row of ``eta`` per document.
        eta: Per-time-slice latent means. When omitted it is obtained from
            ``self.get_eta(self.rnn_inp)``.

    Returns:
        ``(theta, kl_theta)`` while ``self.training`` is true, otherwise
        only ``theta`` (softmax over the last axis).
    """
    # Rows that sum to zero (empty documents) would give 0/0 = NaN and poison
    # the whole batch; dividing those rows by 1 leaves them as all zeros.
    row_totals = bows.sum(1, keepdim=True)
    safe_totals = torch.where(row_totals == 0, torch.ones_like(row_totals), row_totals)
    normalized_bows = bows / safe_totals

    # Previously eta was only filled in outside training mode, so a missing
    # eta during training crashed on ``eta[times]``.
    if eta is None:
        eta, _ = self.get_eta(self.rnn_inp)

    eta_td = eta[times]
    inp = torch.cat([normalized_bows, eta_td], dim=1)
    q_theta = self.q_theta(inp)
    if self.enc_drop > 0:
        q_theta = self.t_drop(q_theta)
    mu_theta = self.mu_q_theta(q_theta)
    logsigma_theta = self.logsigma_q_theta(q_theta)
    z = self.reparameterize(mu_theta, logsigma_theta)
    theta = F.softmax(z, dim=-1)
    # Prior for theta: mean eta_td, log-sigma zero.
    kl_theta = self.get_kl(mu_theta, logsigma_theta, eta_td, torch.zeros(self.num_topics).to(self.device))

    if self.training:
        return theta, kl_theta
    return theta
|
| 188 |
+
|
| 189 |
+
@property
def word_embeddings(self):
    """Return the word-embedding matrix held in ``self.rho``.

    ``get_beta`` calls ``self.rho`` as a layer when ``train_WE`` is set and
    treats it as a plain tensor (``self.rho.permute``) otherwise. A tensor
    has no ``weight`` attribute, so fall back to ``self.rho`` itself instead
    of raising ``AttributeError`` in that case.
    """
    rho = self.rho
    return getattr(rho, "weight", rho)
|
| 192 |
+
|
| 193 |
+
@property
def topic_embeddings(self):
    """Return the first element of ``self.get_alpha()``; the KL term is dropped."""
    return self.get_alpha()[0]
|
| 197 |
+
|
| 198 |
+
def get_beta(self, alpha=None):
    r"""Returns the topic matrix \beta of shape T x K x V

    Args:
        alpha: Topic embeddings whose last axis has size ``self.rho_size``;
            the first two axes are kept in the result. When omitted it is
            taken from ``self.get_alpha()``.

    Returns:
        Softmax over the last axis of the alpha / rho logits.
    """
    # The docstring is a raw string: in a normal string ``\b`` is a backspace.
    # alpha used to be filled in only outside training mode; a missing alpha
    # during training then crashed on ``alpha.view``.
    if alpha is None:
        alpha, _ = self.get_alpha()

    num_times, num_topics = alpha.size(0), alpha.size(1)
    # ``reshape`` also accepts non-contiguous input, unlike ``view``.
    flat_alpha = alpha.reshape(num_times * num_topics, self.rho_size)

    if self.train_WE:
        # rho is a trainable layer mapping embeddings to vocabulary logits.
        logit = self.rho(flat_alpha)
    else:
        # rho is a fixed matrix; multiply by its transpose.
        logit = torch.mm(flat_alpha, self.rho.permute(1, 0))
    logit = logit.view(num_times, num_topics, -1)

    return F.softmax(logit, dim=-1)
|
| 215 |
+
|
| 216 |
+
def get_NLL(self, theta, beta, bows):
    """Return the per-document negative log-likelihood.

    ``theta`` gets a singleton middle axis, is batch-multiplied with ``beta``
    and the log of the result (offset by 1e-12) is weighted by ``bows`` and
    summed over the last axis.
    """
    word_probs = torch.bmm(theta.unsqueeze(1), beta).squeeze(1)
    log_word_probs = torch.log(word_probs + 1e-12)
    return (-log_word_probs * bows).sum(-1)
|
| 223 |
+
|
| 224 |
+
def forward(self, bows, times):
    """Compute the loss terms for one mini-batch.

    The theta KL and the negative log-likelihood are summed over the batch
    and rescaled by ``train_size / batch_size``; the eta and alpha KL terms
    are added unscaled.

    Returns:
        Dict with keys ``loss``, ``nll``, ``kl_eta``, ``kl_theta`` and
        ``kl_alpha`` (in that order).
    """
    scale = self.train_size / bows.size(0)

    eta, kl_eta = self.get_eta(self.rnn_inp)
    theta, kl_theta = self.get_theta(bows, times, eta)
    kl_theta = kl_theta.sum() * scale

    alpha, kl_alpha = self.get_alpha()
    # Select the topic matrix of each document's time slice.
    beta_td = self.get_beta(alpha)[times]

    nll = self.get_NLL(theta, beta_td, bows).sum() * scale

    total = nll + kl_eta + kl_theta + kl_alpha

    return {
        'loss': total,
        'nll': nll,
        'kl_eta': kl_eta,
        'kl_theta': kl_theta,
        'kl_alpha': kl_alpha,
    }
|
| 252 |
+
|
| 253 |
+
def init_hidden(self):
    """Build the zero initial state pair passed to ``self.q_eta`` in ``get_eta``.

    Both tensors have shape ``(eta_nlayers, 1, eta_hidden_size)`` and are
    created with ``new_zeros`` on the first model parameter.
    """
    reference = next(self.parameters())
    state_shape = (self.eta_nlayers, 1, self.eta_hidden_size)
    first_state = reference.new_zeros(state_shape)
    second_state = reference.new_zeros(state_shape)
    return (first_state, second_state)
|
backend/models/DTM_trainer.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gensim
|
| 2 |
+
import numpy as np
|
| 3 |
+
from gensim.models import ldaseqmodel
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import datetime
|
| 6 |
+
from multiprocessing.pool import Pool
|
| 7 |
+
from backend.datasets.utils import _utils
|
| 8 |
+
from backend.datasets.utils.logger import Logger
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = Logger("WARNING")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def work(arguments):
    """Pool worker: index ``model`` with every document of one chunk.

    ``arguments`` is a ``(model, docs)`` pair; the results of ``model[doc]``
    are returned as a list in document order, with a tqdm progress bar.
    """
    model, docs = arguments
    return [model[doc] for doc in tqdm(docs)]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DTMTrainer:
    """Train and query a gensim ``LdaSeqModel`` on a time-stamped corpus.

    The dataset object is expected to expose ``vocab``, ``vocab_size``,
    ``train_bow``, ``train_times`` and ``test_bow`` (these are the attributes
    read below).
    """

    def __init__(self,
                 dataset,
                 num_topics=50,
                 num_top_words=15,
                 alphas=0.01,
                 chain_variance=0.005,
                 passes=10,
                 lda_inference_max_iter=25,
                 em_min_iter=6,
                 em_max_iter=20,
                 verbose=False
                 ):
        """Store the hyper-parameters forwarded to ``LdaSeqModel`` in ``train``."""
        self.dataset = dataset
        self.vocab_size = dataset.vocab_size
        self.num_topics = num_topics
        self.num_top_words = num_top_words
        self.alphas = alphas
        self.chain_variance = chain_variance
        self.passes = passes
        self.lda_inference_max_iter = lda_inference_max_iter
        self.em_min_iter = em_min_iter
        self.em_max_iter = em_max_iter

        self.verbose = verbose
        if verbose:
            logger.set_level("DEBUG")
        else:
            logger.set_level("WARNING")

    def train(self):
        """Fit ``self.model`` on the training documents grouped by time slice."""
        id2word = dict(zip(range(self.vocab_size), self.dataset.vocab))
        train_bow = self.dataset.train_bow
        train_times = self.dataset.train_times.astype('int32')

        # order documents by time slices
        self.doc_order_idx = np.argsort(train_times)
        train_bow = train_bow[self.doc_order_idx]
        time_slices = np.bincount(train_times)

        corpus = gensim.matutils.Dense2Corpus(train_bow, documents_columns=False)

        self.model = ldaseqmodel.LdaSeqModel(
            corpus=corpus,
            id2word=id2word,
            time_slice=time_slices,
            num_topics=self.num_topics,
            alphas=self.alphas,
            chain_variance=self.chain_variance,
            em_min_iter=self.em_min_iter,
            em_max_iter=self.em_max_iter,
            lda_inference_max_iter=self.lda_inference_max_iter,
            passes=self.passes
        )

    def test(self, bow, num_workers=20):
        """Infer document-topic vectors for ``bow`` with a process pool.

        Args:
            bow: Dense document-term matrix, one document per row.
            num_workers: Upper bound on the number of worker processes
                (default 20, the previously hard-coded value). Never more
                workers than documents are started.

        Returns:
            ``np.ndarray`` with one entry per document, in input order.
        """
        # Nothing to infer: avoid spawning a pool just to return an empty array.
        if len(bow) == 0:
            return np.asarray([])

        corpus = gensim.matutils.Dense2Corpus(bow, documents_columns=False)
        docs = list(corpus)

        num_workers = max(1, min(int(num_workers), len(docs)))
        # Contiguous, near-equal chunks so concatenating the per-worker
        # results restores the original document order.
        index_chunks = np.array_split(np.arange(len(docs)), num_workers)
        args_list = [
            [self.model, [docs[i] for i in chunk]]
            for chunk in index_chunks
        ]

        starttime = datetime.datetime.now()

        # The context manager releases the workers even if a task raises.
        with Pool(processes=num_workers) as pool:
            results = pool.map(work, args_list)

        theta_list = list()
        for rst in results:
            theta_list.extend(rst)

        elapsed = datetime.datetime.now() - starttime

        # total_seconds() does not wrap around after 24 hours like .seconds.
        print("DTM test time: {}s".format(int(elapsed.total_seconds())))

        return np.asarray(theta_list)

    def get_theta(self):
        """Return row-normalised ``model.gammas`` in the original document order."""
        theta = self.model.gammas / self.model.gammas.sum(axis=1)[:, np.newaxis]
        # NOTE: MUST transform gamma to original order.
        return theta[np.argsort(self.doc_order_idx)]

    def get_beta(self):
        """Return the topic-word distributions as an array of shape T x K x V."""
        # K x V x T: one V x T matrix per topic chain.
        log_beta = np.asarray([item.e_log_prob for item in self.model.topic_chains])

        # T x K x V
        log_beta = np.transpose(log_beta, (2, 0, 1))
        # Softmax over the vocabulary axis. Subtracting the row maximum first
        # gives the same result but keeps exp() from underflowing to an
        # all-zero row (0/0 = NaN) for very negative log values.
        shifted = log_beta - log_beta.max(axis=-1, keepdims=True)
        beta = np.exp(shifted)
        beta = beta / beta.sum(-1, keepdims=True)
        return beta

    def get_top_words(self, num_top_words=None):
        """Return, per time slice, the result of ``_utils.get_top_words``."""
        if num_top_words is None:
            num_top_words = self.num_top_words
        beta = self.get_beta()
        top_words_list = list()
        for time in range(beta.shape[0]):
            top_words = _utils.get_top_words(beta[time], self.dataset.vocab, num_top_words, self.verbose)
            top_words_list.append(top_words)
        return top_words_list

    def export_theta(self):
        """Return ``(train_theta, test_theta)`` for the training and test splits."""
        train_theta = self.get_theta()
        test_theta = self.test(self.dataset.test_bow)
        return train_theta, test_theta
|
backend/models/dynamic_trainer.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.optim.lr_scheduler import StepLR
|
| 7 |
+
from backend.datasets.utils import _utils
|
| 8 |
+
from backend.datasets.utils.logger import Logger
|
| 9 |
+
|
| 10 |
+
logger = Logger("WARNING")
|
| 11 |
+
|
| 12 |
+
class DynamicTrainer:
    """Training / inference wrapper around a dynamic neural topic model.

    The model is called as ``model(bow, times)`` and must return a dict with
    at least a ``'loss'`` entry; it must also provide ``get_theta`` and
    ``get_beta`` (these are the calls made below).
    """

    def __init__(self,
                 model,
                 dataset,
                 num_top_words=15,
                 epochs=200,
                 learning_rate=0.002,
                 batch_size=200,
                 lr_scheduler=None,
                 lr_step_size=125,
                 log_interval=5,
                 verbose=False
                 ):
        """Store the model, dataset and optimisation settings."""
        self.model = model
        self.dataset = dataset
        self.num_top_words = num_top_words
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.lr_scheduler = lr_scheduler
        self.lr_step_size = lr_step_size
        self.log_interval = log_interval

        self.verbose = verbose
        if verbose:
            logger.set_level("DEBUG")
        else:
            logger.set_level("WARNING")

    def make_optimizer(self,):
        """Return an Adam optimiser over all model parameters."""
        args_dict = {
            'params': self.model.parameters(),
            'lr': self.learning_rate,
        }

        optimizer = torch.optim.Adam(**args_dict)
        return optimizer

    def make_lr_scheduler(self, optimizer):
        """Return a ``StepLR`` that halves the rate every ``lr_step_size`` epochs."""
        # ``verbose`` is deprecated in recent PyTorch releases; False was the
        # default anyway, so it is simply omitted.
        lr_scheduler = StepLR(optimizer, step_size=self.lr_step_size, gamma=0.5)
        return lr_scheduler

    def train(self):
        """Run the training loop.

        Returns:
            ``(top_words, train_theta)`` computed after the last epoch.
        """
        optimizer = self.make_optimizer()

        if self.lr_scheduler:
            logger.info("using lr_scheduler")
            lr_scheduler = self.make_lr_scheduler(optimizer)

        data_size = len(self.dataset.train_dataloader.dataset)

        for epoch in tqdm(range(1, self.epochs + 1)):
            self.model.train()
            loss_rst_dict = defaultdict(float)

            for batch_data in self.dataset.train_dataloader:

                rst_dict = self.model(batch_data['bow'], batch_data['times'])
                batch_loss = rst_dict['loss']

                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

                # Weight by the number of documents in the batch.
                # ``len(batch_data)`` would count the keys of the batch
                # mapping instead.
                num_docs = len(batch_data['bow'])
                for key in rst_dict:
                    # float() detaches the value so the running sums do not
                    # keep every batch's autograd graph alive.
                    loss_rst_dict[key] += float(rst_dict[key]) * num_docs

            if self.lr_scheduler:
                lr_scheduler.step()

            if epoch % self.log_interval == 0:
                output_log = f'Epoch: {epoch:03d}'
                for key in loss_rst_dict:
                    output_log += f' {key}: {loss_rst_dict[key] / data_size :.3f}'

                logger.info(output_log)

        top_words = self.get_top_words()
        train_theta = self.test(self.dataset.train_bow, self.dataset.train_times)

        return top_words, train_theta

    def test(self, bow, times):
        """Infer topic proportions for ``bow`` in chunks of ``batch_size``.

        Returns:
            ``np.ndarray`` with one row per document.
        """
        data_size = bow.shape[0]
        theta = list()
        all_idx = torch.split(torch.arange(data_size), self.batch_size)

        with torch.no_grad():
            self.model.eval()
            for idx in all_idx:
                batch_theta = self.model.get_theta(bow[idx], times[idx])
                theta.extend(batch_theta.cpu().tolist())

        theta = np.asarray(theta)
        return theta

    def get_beta(self):
        """Return ``model.get_beta()`` as a NumPy array (model put in eval mode)."""
        self.model.eval()
        # No gradients are needed for export, so skip building the graph.
        with torch.no_grad():
            beta = self.model.get_beta().detach().cpu().numpy()
        return beta

    def get_top_words(self, num_top_words=None):
        """Return, per time slice, the result of ``_utils.get_top_words``."""
        if num_top_words is None:
            num_top_words = self.num_top_words

        beta = self.get_beta()
        top_words_list = list()
        for time in range(beta.shape[0]):
            if self.verbose:
                print(f"======= Time: {time} =======")
            top_words = _utils.get_top_words(beta[time], self.dataset.vocab, num_top_words, self.verbose)
            top_words_list.append(top_words)
        return top_words_list

    def export_theta(self):
        """Return ``(train_theta, test_theta)`` for the training and test splits."""
        train_theta = self.test(self.dataset.train_bow, self.dataset.train_times)
        test_theta = self.test(self.dataset.test_bow, self.dataset.test_times)

        return train_theta, test_theta

    def _select_top_words(self, topic_beta, top_n):
        """Map the ``top_n`` largest entries of one topic row to vocabulary words."""
        # ``argsort()[-0:]`` would select the whole vocabulary, so handle a
        # non-positive request explicitly.
        if top_n <= 0:
            return []
        top_indices = topic_beta.argsort()[::-1][:top_n]
        return [self.dataset.vocab[i] for i in top_indices]

    def get_top_words_at_time(self, topic_id, time, top_n):
        """Return the ``top_n`` highest-weighted words of one topic at one time step."""
        beta = self.get_beta()  # shape: [T, K, V]
        return self._select_top_words(beta[time, topic_id, :], top_n)

    def get_topic_words_over_time(self, topic_id, top_n):
        """
        Returns top_n words for the given topic_id over all time steps.
        Output: List[List[str]], each inner list is the top_n words at a time step.
        """
        # beta is computed once and reused for every time step.
        beta = self.get_beta()  # shape: [T, K, V]
        T = beta.shape[0]
        return [
            self._select_top_words(beta[t, topic_id, :], top_n)
            for t in range(T)
        ]

    def get_all_topics_at_time(self, time, top_n):
        """
        Returns top_n words for each topic at the given time step.
        Output: List[List[str]], each inner list is the top_n words for a topic.
        """
        # beta is computed once and reused for every topic.
        beta = self.get_beta()  # shape: [T, K, V]
        K = beta.shape[1]
        return [
            self._select_top_words(beta[time, k, :], top_n)
            for k in range(K)
        ]

    def get_all_topics_over_time(self, top_n=10):
        """
        Returns the top_n words for all topics over all time steps.
        Output shape: List[List[List[str]]] = T x K x top_n
        """
        # beta is computed once instead of once per (time, topic) pair.
        beta = self.get_beta()  # shape: [T, K, V]
        T, K, _ = beta.shape
        return [
            [
                self._select_top_words(beta[t, k, :], top_n)
                for k in range(K)
            ]
            for t in range(T)
        ]
|
data/ACL_Anthology/CFDTM/beta.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:34984bfb432a10733161a9dfed834a9ef4f366a28a6cb2ecd6e8351997f1599a
|
| 3 |
+
size 16645248
|
data/ACL_Anthology/DETM/beta.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3c6eefa9b6aaea4c694736d09ad9e517446f09929c01889e26633300e5eff166
|
| 3 |
+
size 41612928
|
data/ACL_Anthology/DTM/beta.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:14c296a2e3fb49f9d0b66262907d64f7d181408768e43138d57c262ea6a11318
|
| 3 |
+
size 33290368
|
data/ACL_Anthology/DTM/topic_label_cache.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ea9f3c508ede82967cdf02050d7383d58dd9d269a7f661ae1462a95cbac3331e
|
| 3 |
+
size 2089
|
data/ACL_Anthology/docs.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a004dd095b9a4f29fdccb5144d50d3dacc7985af443a8de434005b7b8401f9b7
|
| 3 |
+
size 67395059
|
data/ACL_Anthology/inverted_index.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:60e7ee888abb2fd025b11415a7ead6780d41c5f890cc25ba453615906f10b8d7
|
| 3 |
+
size 30865281
|
data/ACL_Anthology/processed/lemma_to_forms.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:00ea8855f9ced2ca3d785ce5926ced29b35e0779cd6b3166edfd5c5a5c1beccb
|
| 3 |
+
size 4370995
|
data/ACL_Anthology/processed/length_stats.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5cc985e5a1ce565ca4179d343ade1526daab463520f6317122953da83d368306
|
| 3 |
+
size 133
|
data/ACL_Anthology/processed/time2id.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"2010": 0,
|
| 3 |
+
"2011": 1,
|
| 4 |
+
"2012": 2,
|
| 5 |
+
"2013": 3,
|
| 6 |
+
"2014": 4,
|
| 7 |
+
"2015": 5,
|
| 8 |
+
"2016": 6,
|
| 9 |
+
"2017": 7,
|
| 10 |
+
"2018": 8,
|
| 11 |
+
"2019": 9,
|
| 12 |
+
"2020": 10,
|
| 13 |
+
"2021": 11,
|
| 14 |
+
"2022": 12,
|
| 15 |
+
"2023": 13,
|
| 16 |
+
"2024": 14,
|
| 17 |
+
"2025": 15
|
| 18 |
+
}
|
data/ACL_Anthology/processed/vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|