Update app.py
Browse files
app.py
CHANGED
|
@@ -22,11 +22,10 @@ except FileNotFoundError:
|
|
| 22 |
st.stop()
|
| 23 |
|
| 24 |
def tokenize(text):
|
| 25 |
-
# Ensure the text is a string before splitting
|
| 26 |
if isinstance(text, str):
|
| 27 |
return text.split()
|
| 28 |
else:
|
| 29 |
-
return []
|
| 30 |
|
| 31 |
def embed_text(text_series, fasttext_model):
|
| 32 |
embeddings = []
|
|
@@ -40,26 +39,21 @@ def embed_text(text_series, fasttext_model):
|
|
| 40 |
return np.array(embeddings)
|
| 41 |
|
| 42 |
def preprocess_input(query, title, description, url, fasttext_model):
|
| 43 |
-
# Convert None or NaN to an empty string to avoid errors during tokenization
|
| 44 |
query = str(query) if pd.notna(query) else ''
|
| 45 |
title = str(title) if pd.notna(title) else ''
|
| 46 |
description = str(description) if pd.notna(description) else ''
|
| 47 |
url = str(url) if pd.notna(url) else ''
|
| 48 |
|
| 49 |
-
# Embed each text field using FastText
|
| 50 |
query_ft = embed_text(pd.Series([query]), fasttext_model)
|
| 51 |
title_ft = embed_text(pd.Series([title]), fasttext_model)
|
| 52 |
description_ft = embed_text(pd.Series([description]), fasttext_model)
|
| 53 |
url_ft = embed_text(pd.Series([url]), fasttext_model)
|
| 54 |
|
| 55 |
-
# Combine embeddings into a single array
|
| 56 |
combined_features = np.hstack([query_ft, title_ft, description_ft, url_ft])
|
| 57 |
|
| 58 |
-
# Convert combined_features to a DMatrix for XGBoost
|
| 59 |
dmatrix = xgb.DMatrix(combined_features)
|
| 60 |
return dmatrix
|
| 61 |
|
| 62 |
-
# Function to extract title and description from a URL
|
| 63 |
def extract_title_description(url):
|
| 64 |
headers = {
|
| 65 |
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.81 Safari/537.36'
|
|
@@ -74,11 +68,10 @@ def extract_title_description(url):
|
|
| 74 |
except Exception as e:
|
| 75 |
return 'Error extracting title', 'Error extracting description'
|
| 76 |
|
| 77 |
-
# Function to make predictions
|
| 78 |
def predict(query, title, description, url, fasttext_model):
|
| 79 |
dmatrix = preprocess_input(query, title, description, url, fasttext_model)
|
| 80 |
-
probability = model.predict(dmatrix, validate_features=False)[0]
|
| 81 |
-
binary_prediction = int(probability >= 0.5)
|
| 82 |
return binary_prediction, probability
|
| 83 |
|
| 84 |
# Streamlit interface
|
|
@@ -101,8 +94,6 @@ with tab1:
|
|
| 101 |
binary_result, confidence = predict(query, title, description, url, fasttext_model)
|
| 102 |
st.write(f'Predicted +/-: {binary_result}')
|
| 103 |
st.write(f'Conf.: {confidence:.2%}')
|
| 104 |
-
|
| 105 |
-
# Convert confidence to a percentage and cast to int
|
| 106 |
confidence_percentage = int(confidence * 100)
|
| 107 |
st.progress(confidence_percentage)
|
| 108 |
else:
|
|
@@ -115,8 +106,6 @@ with tab2:
|
|
| 115 |
|
| 116 |
if uploaded_file is not None:
|
| 117 |
df = pd.read_csv(uploaded_file)
|
| 118 |
-
|
| 119 |
-
# Select only the columns necessary for inference
|
| 120 |
required_columns = ['Query', 'Title', 'Description', 'URL']
|
| 121 |
|
| 122 |
if set(required_columns).issubset(df.columns):
|
|
@@ -127,15 +116,12 @@ with tab2:
|
|
| 127 |
predictions.append(binary_result)
|
| 128 |
confidences.append(confidence)
|
| 129 |
|
| 130 |
-
# Add binary predictions and confidence to the DataFrame
|
| 131 |
df['+/-'] = predictions
|
| 132 |
df['Conf.'] = [f"{conf:.2%}" for conf in confidences]
|
| 133 |
|
| 134 |
-
# Reorder the columns to put '+/-' and 'Conf.' at the front
|
| 135 |
cols = ['+/-', 'Conf.'] + [col for col in df.columns if col not in ['+/-', 'Conf.']]
|
| 136 |
df = df[cols]
|
| 137 |
|
| 138 |
-
# Display and allow download of results
|
| 139 |
st.write(df)
|
| 140 |
st.download_button("Download Predictions", df.to_csv(index=False), "predictions.csv")
|
| 141 |
else:
|
|
@@ -149,11 +135,13 @@ with tab3:
|
|
| 149 |
|
| 150 |
if st.button('Scrape A/B'):
|
| 151 |
title_A, description_A = extract_title_description(url)
|
|
|
|
|
|
|
| 152 |
st.write(f'Extracted Title A: {title_A}')
|
| 153 |
st.write(f'Extracted Description A: {description_A}')
|
| 154 |
|
| 155 |
-
title_B = st.text_input('Title B', value=title_A)
|
| 156 |
-
description_B = st.text_area('Description B', value=description_A)
|
| 157 |
|
| 158 |
if st.button('Predict A/B'):
|
| 159 |
if query and url:
|
|
@@ -163,7 +151,6 @@ with tab3:
|
|
| 163 |
st.write(f'Results for A: Predicted +/-: {binary_result_A}, Conf.: {confidence_A:.2%}')
|
| 164 |
st.write(f'Results for B: Predicted +/-: {binary_result_B}, Conf.: {confidence_B:.2%}')
|
| 165 |
|
| 166 |
-
# Determine improvement
|
| 167 |
if binary_result_A == 1 and binary_result_B == 0:
|
| 168 |
st.write("B is worse than A")
|
| 169 |
elif binary_result_A == 0 and binary_result_B == 1:
|
|
|
|
| 22 |
st.stop()
|
| 23 |
|
| 24 |
def tokenize(text):
|
|
|
|
| 25 |
if isinstance(text, str):
|
| 26 |
return text.split()
|
| 27 |
else:
|
| 28 |
+
return []
|
| 29 |
|
| 30 |
def embed_text(text_series, fasttext_model):
|
| 31 |
embeddings = []
|
|
|
|
| 39 |
return np.array(embeddings)
|
| 40 |
|
| 41 |
def preprocess_input(query, title, description, url, fasttext_model):
|
|
|
|
| 42 |
query = str(query) if pd.notna(query) else ''
|
| 43 |
title = str(title) if pd.notna(title) else ''
|
| 44 |
description = str(description) if pd.notna(description) else ''
|
| 45 |
url = str(url) if pd.notna(url) else ''
|
| 46 |
|
|
|
|
| 47 |
query_ft = embed_text(pd.Series([query]), fasttext_model)
|
| 48 |
title_ft = embed_text(pd.Series([title]), fasttext_model)
|
| 49 |
description_ft = embed_text(pd.Series([description]), fasttext_model)
|
| 50 |
url_ft = embed_text(pd.Series([url]), fasttext_model)
|
| 51 |
|
|
|
|
| 52 |
combined_features = np.hstack([query_ft, title_ft, description_ft, url_ft])
|
| 53 |
|
|
|
|
| 54 |
dmatrix = xgb.DMatrix(combined_features)
|
| 55 |
return dmatrix
|
| 56 |
|
|
|
|
| 57 |
def extract_title_description(url):
|
| 58 |
headers = {
|
| 59 |
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.81 Safari/537.36'
|
|
|
|
| 68 |
except Exception as e:
|
| 69 |
return 'Error extracting title', 'Error extracting description'
|
| 70 |
|
|
|
|
| 71 |
def predict(query, title, description, url, fasttext_model):
|
| 72 |
dmatrix = preprocess_input(query, title, description, url, fasttext_model)
|
| 73 |
+
probability = model.predict(dmatrix, validate_features=False)[0]
|
| 74 |
+
binary_prediction = int(probability >= 0.5)
|
| 75 |
return binary_prediction, probability
|
| 76 |
|
| 77 |
# Streamlit interface
|
|
|
|
| 94 |
binary_result, confidence = predict(query, title, description, url, fasttext_model)
|
| 95 |
st.write(f'Predicted +/-: {binary_result}')
|
| 96 |
st.write(f'Conf.: {confidence:.2%}')
|
|
|
|
|
|
|
| 97 |
confidence_percentage = int(confidence * 100)
|
| 98 |
st.progress(confidence_percentage)
|
| 99 |
else:
|
|
|
|
| 106 |
|
| 107 |
if uploaded_file is not None:
|
| 108 |
df = pd.read_csv(uploaded_file)
|
|
|
|
|
|
|
| 109 |
required_columns = ['Query', 'Title', 'Description', 'URL']
|
| 110 |
|
| 111 |
if set(required_columns).issubset(df.columns):
|
|
|
|
| 116 |
predictions.append(binary_result)
|
| 117 |
confidences.append(confidence)
|
| 118 |
|
|
|
|
| 119 |
df['+/-'] = predictions
|
| 120 |
df['Conf.'] = [f"{conf:.2%}" for conf in confidences]
|
| 121 |
|
|
|
|
| 122 |
cols = ['+/-', 'Conf.'] + [col for col in df.columns if col not in ['+/-', 'Conf.']]
|
| 123 |
df = df[cols]
|
| 124 |
|
|
|
|
| 125 |
st.write(df)
|
| 126 |
st.download_button("Download Predictions", df.to_csv(index=False), "predictions.csv")
|
| 127 |
else:
|
|
|
|
| 135 |
|
| 136 |
if st.button('Scrape A/B'):
|
| 137 |
title_A, description_A = extract_title_description(url)
|
| 138 |
+
st.session_state['title_A'] = title_A
|
| 139 |
+
st.session_state['description_A'] = description_A
|
| 140 |
st.write(f'Extracted Title A: {title_A}')
|
| 141 |
st.write(f'Extracted Description A: {description_A}')
|
| 142 |
|
| 143 |
+
title_B = st.text_input('Title B', value=st.session_state.get('title_A', ''))
|
| 144 |
+
description_B = st.text_area('Description B', value=st.session_state.get('description_A', ''))
|
| 145 |
|
| 146 |
if st.button('Predict A/B'):
|
| 147 |
if query and url:
|
|
|
|
| 151 |
st.write(f'Results for A: Predicted +/-: {binary_result_A}, Conf.: {confidence_A:.2%}')
|
| 152 |
st.write(f'Results for B: Predicted +/-: {binary_result_B}, Conf.: {confidence_B:.2%}')
|
| 153 |
|
|
|
|
| 154 |
if binary_result_A == 1 and binary_result_B == 0:
|
| 155 |
st.write("B is worse than A")
|
| 156 |
elif binary_result_A == 0 and binary_result_B == 1:
|