# Extracted from a Hugging Face Spaces file view; the page chrome (status text,
# file size, commit-hash and line-number gutters) is not part of the program.
import streamlit as st
import pandas as pd
import time
import os
import plotly.graph_objects as go
# Use the wide page layout for the image grid.
st.set_page_config(layout="wide")

# Sidebar selector for which site's dataset to browse.
data_source = st.sidebar.radio("Source", ["Danbooru", "Gelbooru", "Rule 34"], index=0)

# Each source reads the path of its parquet file from its own environment variable.
_PARQUET_ENV_VARS = {
    "Danbooru": "PARQUET_FILE1",
    "Gelbooru": "PARQUET_FILE2",
    "Rule 34": "PARQUET_FILE3",
}
parquet_env_var = _PARQUET_ENV_VARS[data_source]
parquet_file = os.getenv(parquet_env_var)

# os.getenv returns None when the variable is unset; stop with a readable
# message instead of letting the parquet reader fail on a missing path.
if not parquet_file:
    st.error(f"Environment variable {parquet_env_var} is not set; cannot load {data_source} data.")
    st.stop()
@st.cache_resource
def load_and_preprocess_data(parquet_file):
    """Load the parquet file and precompute the row orderings used for sorting.

    The frame is sorted by ``post_id`` (descending), its ``tags`` column is
    converted to sets, and ``post_id`` becomes the index.

    Args:
        parquet_file: Path handed to ``pd.read_parquet``.

    Returns:
        A ``(frame, orderings)`` tuple where ``orderings`` maps each sort
        label to the index of ``frame`` arranged in that order.
    """
    began = time.time()

    frame = pd.read_parquet(parquet_file).sort_values(by='post_id', ascending=False)
    frame["tags"] = frame["tags"].apply(set)
    frame = frame.set_index('post_id')

    # The frame is already in descending post_id order, so its own index
    # (and the reverse of it) cover the two post-ID orderings.
    descending_ids = frame.index
    orderings = {
        'Post ID (Descending)': descending_ids,
        'Post ID (Ascending)': descending_ids[::-1],
    }
    score_columns = (
        ('AVA Score', 'ava_score'),
        ('Aesthetic Score', 'aesthetic_score'),
    )
    for label, column in score_columns:
        orderings[label] = frame[column].sort_values(ascending=False).index

    print(f"Data loaded and preprocessed: {time.time() - began:.2f} seconds")
    return frame, orderings
st.title(f'{data_source} Images')
data, sorted_indices = load_and_preprocess_data(parquet_file)

# Score range filters.
score_range = st.sidebar.slider(
    'Select AVA Score range',
    min_value=0.0,
    max_value=10.0,
    value=(5.5, 10.0),
    step=0.1,
    help='Filter images based on their AVA Score range.',
)
score_range_v2 = st.sidebar.slider(
    'Select Aesthetic Score range',
    min_value=0.0,
    max_value=10.0,
    value=(9.0, 10.0),
    step=0.1,
    help='Filter images based on their Aesthetic Score range.',
)

# Post ID bounds come from the loaded data, with fixed fallbacks when it is empty.
if data.empty:
    min_post_id, max_post_id = 0, 100000
else:
    min_post_id = int(data.index.min())
    max_post_id = int(data.index.max())
post_id_range = st.sidebar.slider(
    'Select Post ID range',
    min_value=min_post_id,
    max_value=max_post_id,
    value=(min_post_id, max_post_id),
    step=1000,
    help='Filter images based on Post ID range.',
)

# Rating filter; only offered from the data when a 'rating' column exists.
if 'rating' in data.columns:
    available_ratings = sorted(data['rating'].unique().tolist())
else:
    available_ratings = ['general']
selected_ratings = st.sidebar.multiselect(
    'Select ratings to include',
    options=available_ratings,
    default=[],
    help='Filter images by their rating category',
)

page_number = st.sidebar.number_input(
    'Page',
    min_value=1,
    value=1,
    step=1,
    help='Navigate through the pages of filtered results.',
)
sort_option = st.sidebar.selectbox(
    'Sort by (slow)',
    options=['Post ID (Descending)', 'Post ID (Ascending)', 'AVA Score', 'Aesthetic Score'],
    index=0,
    help='Select sorting option for the results.',
)

# Tag query: whitespace-separated tokens; a leading "-" marks a tag to exclude.
user_input_tags = st.text_input(
    'Enter tags (space-separated)',
    value='',
    help='Filter images based on tags. Use "-" to exclude tags.',
)
tag_tokens = user_input_tags.split()
selected_tags = {token for token in tag_tokens if not token.startswith('-')}
undesired_tags = {token[1:] for token in tag_tokens if token.startswith('-')}
print(f"Selected tags: {selected_tags}, Undesired tags: {undesired_tags}")
# Function to filter data based on user input
def filter_data(df, score_range, score_range_v2, post_id_range, selected_tags, sort_option, selected_ratings,
                undesired_tags=None, sorted_indices=None):
    """Return the rows of ``df`` that satisfy every active filter, in the requested order.

    Args:
        df: Frame indexed by post ID with ``ava_score``, ``aesthetic_score``
            and ``tags`` columns, and optionally a ``rating`` column.
        score_range: ``(low, high)`` inclusive bounds for ``ava_score``.
        score_range_v2: ``(low, high)`` inclusive bounds for ``aesthetic_score``.
        post_id_range: ``(low, high)`` inclusive bounds for the index.
        selected_tags: Tags that must all be present on a row.
        sort_option: Sort label; ``"Post ID (Descending)"`` keeps ``df``'s own order.
        selected_ratings: Ratings to keep; an empty selection disables the rating filter.
        undesired_tags: Tags that must not be present on a row. When omitted,
            the module-level ``undesired_tags`` is used (empty set if absent).
        sorted_indices: Mapping of sort label to a precomputed index ordering.
            When omitted, the module-level ``sorted_indices`` is used.

    Returns:
        The filtered frame. All columns are kept even when no row matches.
    """
    start_time = time.time()

    # Existing callers pass only the positional arguments, so fall back to the
    # module-level value they relied on before this became a parameter.
    if undesired_tags is None:
        undesired_tags = globals().get('undesired_tags', set())

    in_range = (
        df['ava_score'].between(score_range[0], score_range[1])
        & df['aesthetic_score'].between(score_range_v2[0], score_range_v2[1])
        & (df.index >= post_id_range[0])
        & (df.index <= post_id_range[1])
    )
    filtered_data = df[in_range]

    if selected_ratings and 'rating' in df.columns:
        filtered_data = filtered_data[filtered_data['rating'].isin(selected_ratings)]
    print(f"Data filtered based on scores, post ID and ratings: {time.time() - start_time:.2f} seconds")

    if sort_option != "Post ID (Descending)":
        if sorted_indices is None:
            sorted_indices = globals()['sorted_indices']
        # Walk the precomputed ordering and keep only labels that survived filtering.
        sorted_index = sorted_indices[sort_option]
        sorted_index = sorted_index[sorted_index.isin(filtered_data.index)]
        filtered_data = filtered_data.loc[sorted_index]
    print(f"Applying indcies: {time.time() - start_time:.2f} seconds")

    if selected_tags or undesired_tags:
        required = set(selected_tags)
        banned = set(undesired_tags)
        # apply() on an empty column yields an object-dtype Series, which pandas
        # would treat as a list of column labels rather than a row mask and so
        # return a frame with no columns. Forcing bool keeps it a row mask.
        tag_mask = filtered_data['tags'].apply(
            lambda tags: required.issubset(tags) and banned.isdisjoint(tags)
        ).astype(bool)
        filtered_data = filtered_data[tag_mask]
    print(f"Data filtered: {time.time() - start_time:.2f} seconds")
    return filtered_data
# Apply every sidebar filter to the loaded dataset.
filtered_data = filter_data(data, score_range, score_range_v2, post_id_range, selected_tags, sort_option, selected_ratings)
st.sidebar.write(f"Total filtered images: {len(filtered_data)}")

# Slice out the rows that belong to the requested page.
items_per_page = 50
start_idx = items_per_page * (page_number - 1)
end_idx = start_idx + items_per_page
current_data = filtered_data.iloc[start_idx:end_idx]

# Post page URL pattern for the selected source; it does not change per image.
post_url_templates = {
    "Danbooru": "https://danbooru.donmai.us/posts/{post_id}",
    "Gelbooru": "https://gelbooru.com/index.php?page=post&s=view&id={post_id}",
    "Rule 34": "https://rule34.xxx/index.php?page=post&s=view&id={post_id}",
}
post_url_template = post_url_templates[data_source]

# Lay the page out as a grid, one st.columns() strip per group of images.
columns_per_row = 5
for offset in range(0, len(current_data), columns_per_row):
    grid_row = current_data.iloc[offset:offset + columns_per_row]
    cols = st.columns(columns_per_row)
    for col, (_, row_data) in zip(cols, grid_row.iterrows()):
        post_id = row_data.name
        link = post_url_template.format(post_id=post_id)
        caption = (
            f"ID: {row_data.name}, AVA Score: {row_data['ava_score']:.2f}, "
            f"Aesthetic Score: {row_data['aesthetic_score']:.2f}\n{link}"
        )
        with col:
            st.image(row_data['large_file_url'], caption=caption, use_container_width=True)
def histogram_slider(df, column1, column2):
    """Render overlaid histograms of two columns of ``df`` in the sidebar.

    At most 10000 randomly sampled rows are plotted.

    Args:
        df: Frame containing both columns.
        column1: Name of the first column; also used as its trace name.
        column2: Name of the second column; also used as its trace name.
    """
    sample_size = min(10000, len(df))
    sample_data = df.sample(sample_size)

    fig = go.Figure()
    for column in (column1, column2):
        fig.add_trace(go.Histogram(x=sample_data[column], nbinsx=50, name=column, opacity=0.75))

    layout_options = {
        'barmode': 'overlay',
        'bargap': 0.1,
        'height': 200,
        'xaxis': dict(showticklabels=True),
        'yaxis': dict(showticklabels=True),
        'margin': dict(l=0, r=0, t=0, b=0),
        'legend': dict(orientation='h', yanchor='bottom', y=-0.4, xanchor='center', x=0.5),
    }
    fig.update_layout(**layout_options)

    st.sidebar.plotly_chart(fig, use_container_width=True, config={'displayModeBar': False})
# Draw the sidebar histogram of both score columns for the current filter
# result and log how long it took.
start_time = time.time()
histogram_slider(filtered_data, 'ava_score', 'aesthetic_score')
print(f"Histogram displayed: {time.time() - start_time:.2f} seconds")