import os
import streamlit as st
import pandas as pd
import joblib
import numpy as np
import string
import nltk
from nltk.corpus import stopwords
from nltk import pos_tag, word_tokenize, sent_tokenize
from nltk.stem import WordNetLemmatizer
# Keep NLTK data in a writable app-local path and make NLTK search it.
# (Appending to nltk.data.path works even after import; setting the
# NLTK_DATA environment variable at this point would be too late.)
NLTK_DATA_PATH = "/app/nltk_data"
os.makedirs(NLTK_DATA_PATH, exist_ok=True)
nltk.data.path.append(NLTK_DATA_PATH)

# Download the required NLTK resources once, into the shared directory.
for resource in (
    "punkt",
    "punkt_tab",
    "averaged_perceptron_tagger",
    "averaged_perceptron_tagger_eng",
    "wordnet",
    "stopwords",
):
    nltk.download(resource, download_dir=NLTK_DATA_PATH, quiet=True)

# === Cleaning Function ===
def sahi_karneka_function(x):
    """Clean a question: tokenize, drop punctuation and stopwords,
    keep nouns and verbs, and lemmatize what remains."""
    lem = WordNetLemmatizer()
    stop_words = set(stopwords.words("english"))

    # Tokenize sentence by sentence, lowercasing every word.
    tokens = [word
              for sentence in sent_tokenize(x)
              for word in word_tokenize(sentence.lower())]

    # Drop punctuation and stopwords in a single pass.
    tokens = [t for t in tokens
              if t not in string.punctuation and t not in stop_words]

    # Keep only nouns (NN*) and verbs (V*); other parts of speech
    # rarely carry tag information.
    nouns = [word for word, tag in pos_tag(tokens)
             if tag.startswith("NN") or tag.startswith("V")]

    # Lemmatize: nouns as nouns, everything else as verbs.
    semi_final_words = [lem.lemmatize(word, pos="n") if tag.startswith("NN")
                        else lem.lemmatize(word, pos="v")
                        for word, tag in pos_tag(nouns)]
    return " ".join(semi_final_words)
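
# Example of what the cleaner produces (hypothetical output; the exact
# tokens depend on the installed NLTK tagger and lemmatizer models):
#   sahi_karneka_function("How do I merge two sorted lists in Python?")
#   -> something like "merge sort list python"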

# === Load Data and Models ===
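# These artifacts are assumed to come from the training pipeline:
# c_d.csv holds the questions, logistic_models.pkl a fitted multi-label
# scikit-learn classifier, tfidf.pkl the fitted TF-IDF vectorizer, and
# multilabels.pkl the MultiLabelBinarizer used to encode the tags.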
df = pd.read_csv("src/c_d.csv")
model = joblib.load("src/logistic_models.pkl")
tfidf = joblib.load("src/tfidf.pkl")
ml = joblib.load("src/multilabels.pkl")

# === Streamlit UI ===
st.title("🧠 Multi-Label Question Tag Predictor")

# --- Select a URL for context ---
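# The selection is display-only; it is not used as model input.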
selected_url = st.selectbox("Select a question URL (for context):", df['questions_url'])
st.markdown(f"🔗 [Open selected question]({selected_url})")

# --- Session State ---
if "user_input" not in st.session_state:
    st.session_state["user_input"] = ""
if "clear_input" not in st.session_state:
    st.session_state["clear_input"] = False

# --- Clear input if flagged (must run BEFORE the text_area below,
# since a widget-backed key cannot be reassigned after its widget renders) ---
if st.session_state.clear_input:
    st.session_state.user_input = ""
    st.session_state.clear_input = False

# --- Input box ---
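# key="user_input" binds this widget to st.session_state.user_input.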
st.text_area("✍️ Type your question here:", key="user_input", height=150)

# --- Predict button ---
if st.button("Predict Tags"):
    final_question = st.session_state.user_input.strip()

    if not final_question:
        st.warning("⚠️ Please enter a question.")
    else:
        with st.spinner("🔍 Predicting tags..."):
            # Step 1: Clean input
            cleaned = sahi_karneka_function(final_question)

            # Step 2: TF-IDF
            x_tfidf = tfidf.transform([cleaned])

            # Step 3: Predict with a custom probability threshold.
            # predict_proba returns one (n_samples, 2) array per label;
            # column 1 holds the positive-class probability.
            y_probs = model.predict_proba(x_tfidf)
            threshold = 0.55
            positive_probs = np.array([p[:, 1] for p in y_probs]).T
            y_pred = (positive_probs >= threshold).astype(int)

            # Step 4: Decode the indicator matrix back into tag names
            predicted_tags = ml.inverse_transform(y_pred)

            # Step 5: Display results
            st.success("✅ Predicted Tags:")
            if predicted_tags and predicted_tags[0]:
                for tag in predicted_tags[0]:
                    st.markdown(f"🔹 **`{tag}`**")
            else:
                st.info("No tags matched the threshold.")

# --- Clear button ---
# A button nested inside the predict branch never fires (the outer
# button's state is lost on the rerun its click triggers), and a
# widget-backed key cannot be assigned after the widget renders.
# Instead, set the flag in a callback; the block above the text_area
# empties the field on the next run.
def flag_clear():
    st.session_state.clear_input = True

st.button("Clear Input", on_click=flag_clear)
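
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py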