Spaces:
Sleeping
Sleeping
test
#1
by
Johnny-Z
- opened
app.py
CHANGED
|
@@ -1,16 +1,24 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
import time
|
| 4 |
-
import json
|
| 5 |
import os
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
|
|
|
|
| 8 |
st.set_page_config(layout="wide")
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
@st.cache_resource
|
| 11 |
-
def load_and_preprocess_data():
|
| 12 |
start_time = time.time()
|
| 13 |
-
df = pd.read_parquet(
|
| 14 |
df = df.sort_values(by='post_id', ascending=False)
|
| 15 |
df["tags"] = df["tags"].apply(lambda x: set(x))
|
| 16 |
df.set_index('post_id', inplace=True)
|
|
@@ -18,40 +26,62 @@ def load_and_preprocess_data():
|
|
| 18 |
sorted_indices = {
|
| 19 |
'Post ID (Descending)': df.index,
|
| 20 |
'Post ID (Ascending)': df.index[::-1],
|
| 21 |
-
'
|
| 22 |
-
'
|
| 23 |
}
|
| 24 |
print(f"Data loaded and preprocessed: {time.time() - start_time:.2f} seconds")
|
| 25 |
return df, sorted_indices
|
| 26 |
|
| 27 |
-
st.title('
|
| 28 |
-
data, sorted_indices = load_and_preprocess_data()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
# sidebar
|
| 31 |
-
st.sidebar.header('Filter Options')
|
| 32 |
-
st.sidebar.write('Adjust the filter options to refine the results.')
|
| 33 |
-
score_range = st.sidebar.slider('Select clip score range', min_value=0.0, max_value=10.0, value=(0.0, 10.0), step=0.1, help='Filter images based on their CLIP score range.')
|
| 34 |
-
score_range_v2 = st.sidebar.slider('Select siglip score range', min_value=0.0, max_value=10.0, value=(6.0, 10.0), step=0.1, help='Filter images based on their SigLIP score range.')
|
| 35 |
page_number = st.sidebar.number_input('Page', min_value=1, value=1, step=1, help='Navigate through the pages of filtered results.')
|
| 36 |
-
sort_option = st.sidebar.selectbox('Sort by (slow)', options=['Post ID (Descending)', 'Post ID (Ascending)', '
|
| 37 |
|
| 38 |
# user input
|
| 39 |
-
user_input_tags = st.text_input('Enter tags (space-separated)', help='Filter images based on tags. Use "-" to exclude tags.')
|
| 40 |
selected_tags = set([tag.strip() for tag in user_input_tags.split() if tag.strip() and not tag.strip().startswith('-')])
|
| 41 |
undesired_tags = set([tag[1:] for tag in user_input_tags.split() if tag.startswith('-')])
|
| 42 |
print(f"Selected tags: {selected_tags}, Undesired tags: {undesired_tags}")
|
| 43 |
|
| 44 |
# Function to filter data based on user input
|
| 45 |
-
def filter_data(df, score_range, score_range_v2, selected_tags, sort_option):
|
| 46 |
start_time = time.time()
|
| 47 |
|
| 48 |
filtered_data = df[
|
| 49 |
-
(df['
|
| 50 |
-
(df['
|
| 51 |
-
(df['
|
| 52 |
-
(df['
|
|
|
|
|
|
|
| 53 |
]
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
if sort_option != "Post ID (Descending)":
|
| 57 |
sorted_index = sorted_indices[sort_option]
|
|
@@ -66,11 +96,11 @@ def filter_data(df, score_range, score_range_v2, selected_tags, sort_option):
|
|
| 66 |
return filtered_data
|
| 67 |
|
| 68 |
# Filter data
|
| 69 |
-
filtered_data = filter_data(data, score_range, score_range_v2, selected_tags, sort_option)
|
| 70 |
st.sidebar.write(f"Total filtered images: {len(filtered_data)}")
|
| 71 |
|
| 72 |
# Pagination
|
| 73 |
-
items_per_page =
|
| 74 |
start_idx = (page_number - 1) * items_per_page
|
| 75 |
end_idx = start_idx + items_per_page
|
| 76 |
current_data = filtered_data.iloc[start_idx:end_idx]
|
|
@@ -82,11 +112,18 @@ for row in rows:
|
|
| 82 |
cols = st.columns(columns_per_row)
|
| 83 |
for col, (_, row_data) in zip(cols, row.iterrows()):
|
| 84 |
with col:
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
def histogram_slider(df, column1, column2):
|
| 89 |
-
sample_data = df.sample(min(
|
| 90 |
|
| 91 |
fig = go.Figure()
|
| 92 |
fig.add_trace(go.Histogram(x=sample_data[column1], nbinsx=50, name=column1, opacity=0.75))
|
|
@@ -104,5 +141,5 @@ def histogram_slider(df, column1, column2):
|
|
| 104 |
|
| 105 |
# histogram
|
| 106 |
start_time = time.time()
|
| 107 |
-
histogram_slider(filtered_data, '
|
| 108 |
print(f"Histogram displayed: {time.time() - start_time:.2f} seconds")
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
import time
|
|
|
|
| 4 |
import os
|
| 5 |
import plotly.graph_objects as go
|
| 6 |
|
| 7 |
+
|
| 8 |
st.set_page_config(layout="wide")
|
| 9 |
|
| 10 |
+
data_source = st.sidebar.radio("Source", ["Danbooru", "Gelbooru", "Rule 34"], index=0)
|
| 11 |
+
if data_source == "Danbooru":
|
| 12 |
+
parquet_file = os.getenv('PARQUET_FILE1')
|
| 13 |
+
elif data_source == "Gelbooru":
|
| 14 |
+
parquet_file = os.getenv('PARQUET_FILE2')
|
| 15 |
+
elif data_source == "Rule 34":
|
| 16 |
+
parquet_file = os.getenv('PARQUET_FILE3')
|
| 17 |
+
|
| 18 |
@st.cache_resource
|
| 19 |
+
def load_and_preprocess_data(parquet_file):
|
| 20 |
start_time = time.time()
|
| 21 |
+
df = pd.read_parquet(parquet_file)
|
| 22 |
df = df.sort_values(by='post_id', ascending=False)
|
| 23 |
df["tags"] = df["tags"].apply(lambda x: set(x))
|
| 24 |
df.set_index('post_id', inplace=True)
|
|
|
|
| 26 |
sorted_indices = {
|
| 27 |
'Post ID (Descending)': df.index,
|
| 28 |
'Post ID (Ascending)': df.index[::-1],
|
| 29 |
+
'AVA Score': df['ava_score'].sort_values(ascending=False).index,
|
| 30 |
+
'Aesthetic Score': df['aesthetic_score'].sort_values(ascending=False).index,
|
| 31 |
}
|
| 32 |
print(f"Data loaded and preprocessed: {time.time() - start_time:.2f} seconds")
|
| 33 |
return df, sorted_indices
|
| 34 |
|
| 35 |
+
st.title(f'{data_source} Images')
|
| 36 |
+
data, sorted_indices = load_and_preprocess_data(parquet_file)
|
| 37 |
+
|
| 38 |
+
score_range = st.sidebar.slider('Select AVA Score range', min_value=0.0, max_value=10.0, value=(5.5, 10.0), step=0.1, help='Filter images based on their AVA Score range.')
|
| 39 |
+
score_range_v2 = st.sidebar.slider('Select Aesthetic Score range', min_value=0.0, max_value=10.0, value=(9.0, 10.0), step=0.1, help='Filter images based on their Aesthetic Score range.')
|
| 40 |
+
|
| 41 |
+
min_post_id = int(data.index.min()) if not data.empty else 0
|
| 42 |
+
max_post_id = int(data.index.max()) if not data.empty else 100000
|
| 43 |
+
post_id_range = st.sidebar.slider('Select Post ID range',
|
| 44 |
+
min_value=min_post_id,
|
| 45 |
+
max_value=max_post_id,
|
| 46 |
+
value=(min_post_id, max_post_id),
|
| 47 |
+
step=1000,
|
| 48 |
+
help='Filter images based on Post ID range.')
|
| 49 |
+
|
| 50 |
+
# Rating choices offered in the sidebar. Missing values are dropped before sorting:
# sorted() raises TypeError when NaN/None is compared with strings.
# NOTE(review): assumes ratings are strings when present — confirm against the parquet schema.
available_ratings = sorted(data['rating'].dropna().unique().tolist()) if 'rating' in data.columns else ['general']
|
| 51 |
+
|
| 52 |
+
selected_ratings = st.sidebar.multiselect(
|
| 53 |
+
'Select ratings to include',
|
| 54 |
+
options=available_ratings,
|
| 55 |
+
default=[],
|
| 56 |
+
help='Filter images by their rating category'
|
| 57 |
+
)
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
page_number = st.sidebar.number_input('Page', min_value=1, value=1, step=1, help='Navigate through the pages of filtered results.')
|
| 60 |
+
sort_option = st.sidebar.selectbox('Sort by (slow)', options=['Post ID (Descending)', 'Post ID (Ascending)', 'AVA Score', 'Aesthetic Score'], index=0, help='Select sorting option for the results.')
|
| 61 |
|
| 62 |
# user input
|
| 63 |
+
user_input_tags = st.text_input('Enter tags (space-separated)', value='', help='Filter images based on tags. Use "-" to exclude tags.')
|
| 64 |
# Tags to require: every token that does not start with "-".
# str.split() with no separator already discards empty tokens and surrounding whitespace, so no strip() is needed.
selected_tags = {tag for tag in user_input_tags.split() if not tag.startswith('-')}
|
| 65 |
# Tags to exclude: tokens prefixed with "-", stored without the prefix.
# A bare "-" carries no tag name, so it is skipped rather than added as an empty-string tag.
undesired_tags = {tag[1:] for tag in user_input_tags.split() if tag.startswith('-') and len(tag) > 1}
|
| 66 |
print(f"Selected tags: {selected_tags}, Undesired tags: {undesired_tags}")
|
| 67 |
|
| 68 |
# Function to filter data based on user input
|
| 69 |
+
def filter_data(df, score_range, score_range_v2, post_id_range, selected_tags, sort_option, selected_ratings):
|
| 70 |
start_time = time.time()
|
| 71 |
|
| 72 |
filtered_data = df[
|
| 73 |
+
(df['ava_score'] >= score_range[0]) &
|
| 74 |
+
(df['ava_score'] <= score_range[1]) &
|
| 75 |
+
(df['aesthetic_score'] >= score_range_v2[0]) &
|
| 76 |
+
(df['aesthetic_score'] <= score_range_v2[1]) &
|
| 77 |
+
(df.index >= post_id_range[0]) &
|
| 78 |
+
(df.index <= post_id_range[1])
|
| 79 |
]
|
| 80 |
+
|
| 81 |
+
if selected_ratings and 'rating' in df.columns:
|
| 82 |
+
filtered_data = filtered_data[filtered_data['rating'].isin(selected_ratings)]
|
| 83 |
+
|
| 84 |
+
print(f"Data filtered based on scores, post ID and ratings: {time.time() - start_time:.2f} seconds")
|
| 85 |
|
| 86 |
if sort_option != "Post ID (Descending)":
|
| 87 |
sorted_index = sorted_indices[sort_option]
|
|
|
|
| 96 |
return filtered_data
|
| 97 |
|
| 98 |
# Filter data
|
| 99 |
+
filtered_data = filter_data(data, score_range, score_range_v2, post_id_range, selected_tags, sort_option, selected_ratings)
|
| 100 |
st.sidebar.write(f"Total filtered images: {len(filtered_data)}")
|
| 101 |
|
| 102 |
# Pagination
|
| 103 |
+
items_per_page = 50
|
| 104 |
start_idx = (page_number - 1) * items_per_page
|
| 105 |
end_idx = start_idx + items_per_page
|
| 106 |
current_data = filtered_data.iloc[start_idx:end_idx]
|
|
|
|
| 112 |
cols = st.columns(columns_per_row)
|
| 113 |
for col, (_, row_data) in zip(cols, row.iterrows()):
|
| 114 |
with col:
|
| 115 |
+
post_id = row_data.name
|
| 116 |
+
if data_source == "Danbooru":
|
| 117 |
+
link = f"https://danbooru.donmai.us/posts/{post_id}"
|
| 118 |
+
elif data_source == "Gelbooru":
|
| 119 |
+
link = f"https://gelbooru.com/index.php?page=post&s=view&id={post_id}"
|
| 120 |
+
elif data_source == "Rule 34":
|
| 121 |
+
link = f"https://rule34.xxx/index.php?page=post&s=view&id={post_id}"
|
| 122 |
+
|
| 123 |
+
st.image(row_data['large_file_url'], caption=f"ID: {row_data.name}, AVA Score: {row_data['ava_score']:.2f}, Aesthetic Score: {row_data['aesthetic_score']:.2f}\n{link}", use_container_width=True)
|
| 124 |
|
| 125 |
def histogram_slider(df, column1, column2):
|
| 126 |
+
sample_data = df.sample(min(10000, len(df)))
|
| 127 |
|
| 128 |
fig = go.Figure()
|
| 129 |
fig.add_trace(go.Histogram(x=sample_data[column1], nbinsx=50, name=column1, opacity=0.75))
|
|
|
|
| 141 |
|
| 142 |
# histogram
|
| 143 |
start_time = time.time()
|
| 144 |
+
histogram_slider(filtered_data, 'ava_score', 'aesthetic_score')
|
| 145 |
print(f"Histogram displayed: {time.time() - start_time:.2f} seconds")
|