# Bunka Map — Streamlit app for generating, curating, and exporting topic maps.
import os

import streamlit as st
from bunkatopics import Bunka
from datasets import load_dataset
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
# --- Page header ------------------------------------------------------------
st.title("Bunka Map 🗺️")

# --- User-configurable parameters -------------------------------------------
dataset_id = st.text_input("Dataset ID", "bunkalab/medium-sample-technology")
language = st.text_input("Language", "english")
text_field = st.text_input("Text Field", "title")
embedder_model = st.text_input(
    "Embedder Model",
    "sentence-transformers/distiluse-base-multilingual-cased-v2",
)
sample_size = st.number_input("Sample Size", min_value=100, max_value=10000, value=1000)
n_clusters = st.number_input("Number of Clusters", min_value=5, max_value=50, value=15)
llm_model = st.text_input("LLM Model", "mistralai/Mistral-7B-Instruct-v0.1")

# Token is only required for the optional LLM-based topic relabelling step.
hf_token = st.text_input("Hugging Face API Token", type="password")
# ---------------------------------------------------------------------------
# Map generation and topic-curation workflow.
#
# NOTE(review): the original nested every follow-up action (LLM relabelling,
# manual topic cleaning, topic removal, saving) under the "Generate Bunka Map"
# button. A Streamlit button is True only for the single rerun in which it is
# clicked, so clicking any nested button reran the script with the outer
# button False — the nested code could never execute, and the fitted `bunka`
# object was lost. The fitted model is now persisted in st.session_state and
# the curation UI is gated on its presence instead.
# ---------------------------------------------------------------------------

def load_data(dataset_id: str, text_field: str, sample_size: int) -> list:
    """Stream up to `sample_size` values of `text_field` from the dataset.

    Streaming avoids downloading the full dataset just to take a sample.
    """
    dataset = load_dataset(dataset_id, streaming=True)
    docs_sample = []
    for i, example in enumerate(dataset["train"]):
        if i >= sample_size:
            break
        docs_sample.append(example[text_field])
    return docs_sample


def plot_map(bunka) -> None:
    """Render the current topic map with the app's shared plot styling."""
    st.plotly_chart(
        bunka.visualize_topics(
            width=800,
            height=800,
            colorscale="Portland",
            density=True,
            label_size_ratio=60,
            convex_hull=True,
        )
    )


if st.button("Generate Bunka Map"):
    docs_sample = load_data(dataset_id, text_field, sample_size)

    # Fit the embedding model and the topic model on the sampled documents.
    embedding_model = HuggingFaceEmbeddings(model_name=embedder_model)
    bunka = Bunka(embedding_model=embedding_model, language=language)
    bunka.fit(docs_sample)
    bunka.get_topics(n_clusters=n_clusters, name_length=5, min_count_terms=2)

    # Persist across reruns so the curation buttons below keep working.
    st.session_state["bunka"] = bunka
    plot_map(bunka)

    # Optionally relabel topics with an LLM hosted on the HF Hub.
    if hf_token:
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
        llm = HuggingFaceHub(repo_id=llm_model, huggingfacehub_api_token=hf_token)
        bunka.get_clean_topic_name(llm=llm, language=language)
        plot_map(bunka)
    else:
        st.warning("Please provide a Hugging Face API token to clean labels using LLM.")

if "bunka" in st.session_state:
    bunka = st.session_state["bunka"]

    # Manual topic cleaning: one editable keyword list per topic.
    st.subheader("Manually Clean Topics")
    cleaned_topics = {}
    for topic, keywords in bunka.topics_.items():
        cleaned_topic = st.text_input(f"Topic {topic}", ", ".join(keywords))
        cleaned_topics[topic] = cleaned_topic.split(", ")
    if st.button("Update Topics"):
        bunka.topics_ = cleaned_topics
        plot_map(bunka)

    # Drop whole topics (and their documents) from the map.
    st.subheader("Remove Unwanted Topics")
    topics_to_remove = st.multiselect("Select topics to remove", list(bunka.topics_.keys()))
    if st.button("Remove Topics"):
        bunka.clean_data_by_topics(topics_to_remove)
        plot_map(bunka)

    # Export: '/' in a HF dataset id is not a valid filename character.
    if st.button("Save Cleaned Dataset"):
        name = dataset_id.replace('/', '_') + '_cleaned.csv'
        bunka.df_cleaned_.to_csv(name)
        st.success(f"Dataset saved as {name}")