"""Streamlit app: Bayesian token co-occurrence simulator.

Tokenizes user-supplied training sentences, builds a symmetric-window
co-occurrence matrix over the vocabulary, applies additive (Laplace /
Dirichlet-prior) smoothing with a user-chosen alpha, shows the smoothed
matrix as a heatmap, and samples a next token for a selected word from
the row-normalized smoothed counts.
"""

import os

import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
from nltk.tokenize import word_tokenize

# NLTK data goes in a writable temp dir: on hosted platforms (e.g.
# Streamlit Cloud) the default NLTK install location is read-only.
nltk_data_dir = "/tmp/nltk_data"
os.makedirs(nltk_data_dir, exist_ok=True)
nltk.data.path.append(nltk_data_dir)

# Check each tokenizer resource independently. The original code only
# probed "punkt"; if "punkt" was present but "punkt_tab" (required by
# newer NLTK releases) was missing, nothing got downloaded and
# word_tokenize crashed.
for _resource, _probe in (
    ("punkt", "tokenizers/punkt"),
    ("punkt_tab", "tokenizers/punkt_tab"),
):
    try:
        nltk.data.find(_probe)
    except LookupError:
        nltk.download(_resource, download_dir=nltk_data_dir)

st.title("📊 Bayesian Token Co-occurrence Simulator")

# User input: one training sentence per line.
user_input = st.text_area(
    "✍️ Enter your training sentences (one per line):",
    """
fido loves the red ball
timmy and fido go to the park
fido and timmy love to play
the red ball is timmy's favorite toy
""",
)

sentences = user_input.strip().split("\n")
tokenized = [word_tokenize(s.lower()) for s in sentences if s.strip()]

vocab = sorted(set(word for sentence in tokenized for word in sentence))

# Guard: an empty/whitespace-only text area would otherwise produce an
# empty heatmap and an empty selectbox further down.
if not vocab:
    st.warning("Please enter at least one non-empty sentence.")
    st.stop()

token2idx = {word: i for i, word in enumerate(vocab)}
idx2token = {i: word for word, i in token2idx.items()}

# Co-occurrence counts within a symmetric window of +/- window_size
# tokens inside each sentence (the center token itself is excluded).
window_size = 2
matrix = np.zeros((len(vocab), len(vocab)))
for sentence in tokenized:
    for i, word in enumerate(sentence):
        lo = max(0, i - window_size)
        hi = min(len(sentence), i + window_size + 1)
        for j in range(lo, hi):
            if i != j:
                matrix[token2idx[word]][token2idx[sentence[j]]] += 1

# Additive smoothing: posterior counts = observed counts + alpha prior.
alpha = st.slider("🔧 Set Bayesian Prior (α smoothing)", 0.0, 2.0, 0.1)
posterior = matrix + alpha

df = pd.DataFrame(posterior, index=vocab, columns=vocab)

st.subheader("📈 Co-occurrence Heatmap")
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(df, annot=True, cmap="Blues", fmt=".1f", ax=ax)
st.pyplot(fig)

# Next-token prediction: sample from the selected word's row,
# normalized to a probability distribution.
selected_word = st.selectbox("🔮 Predict next token after:", vocab)
row = posterior[token2idx[selected_word]]
row_total = row.sum()

if row_total > 0:
    probs = row / row_total
    prediction = np.random.choice(vocab, p=probs)
    st.markdown(f"**Predicted next token:** `{prediction}`")
else:
    # With alpha = 0 a token that only ever appeared alone in a
    # one-word sentence has an all-zero row; dividing by the zero sum
    # would yield NaN probabilities and crash np.random.choice.
    st.warning(
        "No co-occurrence data for this token — increase α above 0 "
        "to enable smoothing."
    )