Spaces:
Running
Running
| from hashlib import sha256 | |
| import json | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
| from typing import Optional | |
| import os | |
| #get_top_words_at_time | |
| from backend.inference.process_beta import get_top_words_at_time | |
| def label_topic_temporal(word_trajectory_str: str, llm, cache_path: Optional[str] = None) -> str: | |
| """ | |
| Label a dynamic topic by providing the LLM with the top words over time. | |
| Args: | |
| word_trajectory_str (str): Formatted keyword evolution string. | |
| llm: LangChain-compatible LLM instance. | |
| cache_path (Optional[str]): Path to the cache file (JSON). | |
| Returns: | |
| str: Short label for the topic. | |
| """ | |
| topic_key = sha256(word_trajectory_str.encode()).hexdigest() | |
| # Load cache | |
| if cache_path is not None and os.path.exists(cache_path): | |
| with open(cache_path, "r") as f: | |
| label_cache = json.load(f) | |
| else: | |
| label_cache = {} | |
| # Return cached result | |
| if topic_key in label_cache: | |
| return label_cache[topic_key] | |
| # Prompt template | |
| prompt = ChatPromptTemplate.from_template( | |
| "You are an expert in topic modeling and temporal data analysis. " | |
| "Given the top words for a topic across multiple time points, your task is to return a short, specific, descriptive topic label. " | |
| "Avoid vague, generic, or overly broad labels. Focus on consistent themes in the top words over time. " | |
| "Use concise noun phrases, 2โ5 words max. Do NOT include any explanation, justification, or extra output.\n\n" | |
| "Top words over time:\n{trajectory}\n\n" | |
| "Return ONLY the label (no quotes, no extra text):" | |
| ) | |
| chain = prompt | llm | StrOutputParser() | |
| try: | |
| label = chain.invoke({"trajectory": word_trajectory_str}).strip() | |
| except Exception as e: | |
| label = "Unknown Topic" | |
| print(f"[Labeling Error] {e}") | |
| # Update cache and save | |
| label_cache[topic_key] = label | |
| if cache_path is not None: | |
| os.makedirs(os.path.dirname(cache_path), exist_ok=True) | |
| with open(cache_path, "w") as f: | |
| json.dump(label_cache, f, indent=2) | |
| return label | |
| def get_topic_labels(beta, vocab, time_labels, llm, cache_path): | |
| topic_labels = {} | |
| for topic_id in range(beta.shape[1]): | |
| word_trajectory_str = "\n".join([ | |
| f"{time_labels[t]}: {', '.join(get_top_words_at_time(beta, vocab, topic_id, t, top_n=10))}" | |
| for t in range(beta.shape[0]) | |
| ]) | |
| label = label_topic_temporal(word_trajectory_str, llm=llm, cache_path=cache_path) | |
| topic_labels[topic_id] = label | |
| return topic_labels |