|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
import re |
|
|
import pandas as pd |
|
|
import joblib |
|
|
import datetime |
|
|
import matplotlib.pyplot as plt |
|
|
from io import BytesIO |
|
|
from nltk.tokenize import TreebankWordTokenizer |
|
|
from nltk.stem import WordNetLemmatizer |
|
|
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS |
|
|
import os |
|
|
import time |
|
|
import zipfile |
|
|
|
|
|
|
|
|
# Pre-trained artifacts produced offline and shipped alongside this app:
# - lda_model.joblib:     fitted sklearn LatentDirichletAllocation model
# - vectorizer.joblib:    the text vectorizer the LDA model was trained with
# - topic_labels.joblib:  mapping of topic index -> human-readable label
#   (presumably dict[int, str] — auto_labels.get(idx, ...) is used below; verify)
lda = joblib.load("lda_model.joblib")

vectorizer = joblib.load("vectorizer.joblib")

auto_labels = joblib.load("topic_labels.joblib")
|
|
|
|
|
|
|
|
# One-sentence description per topic label, shown in the prediction output.
# Keys must match the label strings stored in topic_labels.joblib exactly —
# predict_topic looks summaries up by label and falls back to
# "No summary available." on a miss.
topic_summaries = {
    "Politics & Gun Rights": "Discussions about government policies, laws, gun control, and rights.",
    "Computing & Hardware": "Technical issues and terms related to computer hardware and drivers.",
    "Programming & Software": "Programming terms, file handling, software output.",
    "Sports & Games": "Topics related to teams, players, seasons, and matches.",
    "Health & Medicine": "Diseases, treatment, healthcare, and medical facilities.",
    "Religion & Philosophy": "Talks involving faith, belief systems, philosophical views.",
    "Space & NASA": "Space exploration, NASA missions, satellites, and astronomy.",
    "Cryptography & Security": "Discussions on encryption, digital security, and data protection.",
    "Internet & Networking": "Terms around internet use, FTP, web versions, and networks.",
    "Middle East Politics & Conflicts": "Topics involving Israel, Armenia, conflict regions."
}
|
|
|
|
|
|
|
|
# Shared NLP helpers used by preprocess(); instantiated once at module load.
tokenizer = TreebankWordTokenizer()

lemmatizer = WordNetLemmatizer()
|
|
|
|
|
|
|
|
|
|
|
def preprocess(text):
    """Normalize raw text for the vectorizer.

    Lowercases, collapses non-word characters to spaces, tokenizes, and
    lemmatizes. Keeps only alphabetic tokens longer than two characters
    that are not in scikit-learn's English stop-word list. Returns the
    surviving lemmas joined into one space-separated string.
    """
    normalized = re.sub(r'\W+', ' ', text.lower())
    kept = []
    for word in tokenizer.tokenize(normalized):
        if word in ENGLISH_STOP_WORDS:
            continue
        if len(word) <= 2 or not word.isalpha():
            continue
        kept.append(lemmatizer.lemmatize(word))
    return ' '.join(kept)
|
|
|
|
|
def get_topic_keywords(model, vectorizer, topic_idx, top_n=10):
    """Return the *top_n* highest-weighted vocabulary terms for one topic.

    *model* is a fitted LDA model exposing ``components_`` (topic-term
    weights); *vectorizer* supplies the vocabulary via
    ``get_feature_names_out()``. Terms come back ordered from the heaviest
    weight down.
    """
    vocab = vectorizer.get_feature_names_out()
    weights = model.components_[topic_idx]
    # argsort is ascending, so reverse and take the first top_n indices.
    ranked = weights.argsort()[::-1][:top_n]
    return [vocab[i] for i in ranked]
|
|
|
|
|
def plot_topic_distribution(distribution, labels):
    """Render the per-topic probabilities as a bar chart.

    Returns a PIL Image (drawn from an in-memory PNG) so Gradio's Image
    component can display it without touching disk.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(range(len(distribution)), distribution, tick_label=labels)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    ax.set_ylabel("Probability")
    ax.set_title("Topic Distribution")
    fig.tight_layout()

    png_buffer = BytesIO()
    fig.savefig(png_buffer, format="png")
    plt.close(fig)  # release the figure; pyplot keeps figures alive otherwise
    png_buffer.seek(0)
    return Image.open(png_buffer)
|
|
|
|
|
def save_prediction_file(text):
    """Persist *text* to a timestamped .txt file in the working directory.

    Returns the generated filename; the "lda_prediction_" prefix is what
    cleanup_old_predictions() uses to recognize these files later.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    out_name = f"lda_prediction_{stamp}.txt"
    with open(out_name, "w", encoding="utf-8") as handle:
        handle.write(text)
    return out_name
|
|
|
|
|
def cleanup_old_predictions(directory=".", extension=".txt", max_age_minutes=10):
    """Delete stale prediction files from *directory*.

    Only files matching the "lda_prediction_*<extension>" naming scheme and
    older (by mtime) than *max_age_minutes* are removed. Deletion failures
    are reported but never raised — this is best-effort housekeeping.
    """
    cutoff_seconds = max_age_minutes * 60
    current_time = time.time()
    for entry in os.listdir(directory):
        if not entry.startswith("lda_prediction_"):
            continue
        if not entry.endswith(extension):
            continue
        path = os.path.join(directory, entry)
        if not os.path.isfile(path):
            continue
        if current_time - os.path.getmtime(path) <= cutoff_seconds:
            continue  # still fresh
        try:
            os.remove(path)
        except Exception as e:
            print(f"Failed to delete {entry}: {e}")
|
|
|
|
|
def download_log():
    """Bundle the prediction log CSV into a zip archive for download.

    Returns the zip filename. Previously this raised FileNotFoundError when
    no prediction had been logged yet (the CSV is only created on first
    prediction), which crashed the "Download All Logs" button; now the CSV
    is added only if it exists, so an empty archive is returned instead.
    """
    zip_filename = "lda_predictions_log.zip"
    log_path = "lda_predictions_log.csv"
    with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
        if os.path.exists(log_path):
            zipf.write(log_path)
    return zip_filename
|
|
|
|
|
def save_feedback(text, feedback):
    """Append one feedback row to lda_feedback_log.csv.

    Records an ISO timestamp, the user's rating, and a flattened 300-char
    excerpt of the judged text. The CSV header is written only when the
    file does not exist yet. Returns a short acknowledgement string.
    """
    feedback_log = "lda_feedback_log.csv"
    excerpt = text[:300].replace('\n', ' ') + "..."
    row = {
        "timestamp": datetime.datetime.now().isoformat(),
        "feedback": feedback,
        "text_excerpt": excerpt,
    }
    pd.DataFrame([row]).to_csv(
        feedback_log,
        mode='a',
        header=not os.path.exists(feedback_log),
        index=False,
    )
    return " Feedback recorded. Thank you!"
|
|
|
|
|
|
|
|
|
|
|
def predict_topic(text_input, file_input):
    """Predict the dominant LDA topic for pasted text or an uploaded file.

    Parameters
    ----------
    text_input : str or None
        Raw text from the textbox. Gradio can deliver None for a cleared
        textbox; the original ``text_input.strip()`` raised AttributeError
        in that case — now guarded with a truthiness check first.
    file_input :
        Uploaded .txt file; takes precedence over the textbox when present.
        NOTE(review): assumes the component yields a readable object with
        ``.read()`` returning bytes — newer Gradio versions may pass a
        filepath string instead; confirm against the installed version.

    Returns
    -------
    tuple
        (markdown result string, PIL chart image, path of downloadable
        prediction file), or ("Please provide input", None, None) when
        neither input was supplied.
    """
    # Best-effort removal of stale downloadable prediction files.
    cleanup_old_predictions()

    if file_input is not None:
        text = file_input.read().decode("utf-8")
    elif text_input and text_input.strip():  # guard: textbox may be None
        text = text_input
    else:
        return "Please provide input", None, None

    # Vectorize the cleaned text and get this document's topic probabilities.
    cleaned = preprocess(text)
    bow = vectorizer.transform([cleaned])
    topic_distribution = lda.transform(bow)[0]
    dominant_topic = topic_distribution.argmax()
    label = auto_labels.get(dominant_topic, f"Topic {dominant_topic+1}")
    top_words = get_topic_keywords(lda, vectorizer, dominant_topic)
    summary = topic_summaries.get(label, "No summary available.")

    # Flag weak predictions rather than silently reporting a topic.
    confidence_threshold = 0.4
    if topic_distribution[dominant_topic] < confidence_threshold:
        label += " ( Low confidence)"
        summary = " The model is uncertain. Try providing more context or a longer input."

    # Append this prediction to the CSV audit log (header only on first write).
    timestamp = datetime.datetime.now().isoformat()
    log_entry = pd.DataFrame([{
        "timestamp": timestamp,
        "predicted_topic": label,
        "dominant_topic_index": dominant_topic,
        "top_words": ", ".join(top_words),
        "text_excerpt": text[:300].replace('\n', ' ') + "..."
    }])
    log_path = "lda_predictions_log.csv"
    log_entry.to_csv(log_path, mode='a', header=not os.path.exists(log_path), index=False)

    chart = plot_topic_distribution(
        topic_distribution,
        [auto_labels.get(i, f"Topic {i+1}") for i in range(len(topic_distribution))],
    )

    # Human-readable markdown block shown in the results textbox.
    result = f" **Predicted Topic:** {label}\n\n"
    result += f" **Summary:** {summary}\n\n"
    result += f" **Top Words:** {', '.join(top_words)}\n\n"
    result += " **Topic Distribution:**\n"
    for idx, prob in enumerate(topic_distribution):
        tlabel = auto_labels.get(idx, f"Topic {idx+1}")
        result += f"{tlabel}: {prob:.3f}\n"

    prediction_file = save_prediction_file(result)
    return result, chart, prediction_file
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
# Two-column layout: left column holds the inputs (textbox, file upload,
# action buttons, feedback radio); right column shows the prediction
# markdown, the topic-distribution chart, and a per-prediction download.
with gr.Blocks() as demo:
    gr.Markdown("## Topic Modeling with LDA")
    gr.Markdown("Upload a `.txt` file or paste in text. See predicted topic, keywords, and a chart.")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(lines=10, label=" Paste Text")
            file_input = gr.File(label=" Or Upload a .txt File", file_types=[".txt"])
            predict_btn = gr.Button(" Predict Topic")
            download_btn = gr.Button("⬇ Download All Logs")

            # Lightweight feedback widget; responses are appended to
            # lda_feedback_log.csv by save_feedback().
            feedback_input = gr.Radio(
                choices=["Accurate", " Inaccurate", "Unclear"],
                label=" Was this prediction useful?",
                interactive=True
            )
            feedback_btn = gr.Button("Submit Feedback")
            # Hidden component to receive the feedback acknowledgement string.
            feedback_output = gr.Textbox(visible=False)

        with gr.Column():
            output_text = gr.Textbox(label=" Prediction Result")
            output_chart = gr.Image(type="pil", label=" Topic Distribution")
            download_prediction = gr.File(label="⬇ Download This Prediction")

    # Event wiring: predict fills all three outputs on the right.
    predict_btn.click(
        fn=predict_topic,
        inputs=[text_input, file_input],
        outputs=[output_text, output_chart, download_prediction]
    )

    # Zips the prediction log CSV into a downloadable archive.
    download_btn.click(fn=download_log, outputs=[gr.File()])

    # Feedback uses the current textbox text as the judged excerpt.
    feedback_btn.click(
        fn=save_feedback,
        inputs=[text_input, feedback_input],
        outputs=[feedback_output]
    )

demo.launch()
|
|
|