# Hugging Face Space status: Running
import os
import time

import pandas as pd
import plotly.graph_objects as go
import streamlit as st

# Configure the page before any other Streamlit command renders output.
st.set_page_config(layout="wide")

# Each selectable source is backed by a parquet file whose path is read from
# the named environment variable.
PARQUET_ENV_VARS = {
    "Danbooru": "PARQUET_FILE1",
    "Gelbooru": "PARQUET_FILE2",
    "Rule 34": "PARQUET_FILE3",
}

data_source = st.sidebar.radio("Source", list(PARQUET_ENV_VARS), index=0)
parquet_env_var = PARQUET_ENV_VARS[data_source]
parquet_file = os.getenv(parquet_env_var)
if not parquet_file:
    # Without this guard the missing path reaches pd.read_parquet(None) and
    # fails with an unrelated-looking error; stop the run with a clear message.
    st.error(f"Environment variable {parquet_env_var} is not set; cannot load {data_source} data.")
    st.stop()
def load_and_preprocess_data(parquet_file):
    """Load the posts table from ``parquet_file`` and precompute sort orders.

    The returned frame is ordered by ``post_id`` descending and indexed by it,
    and each row's ``tags`` value is converted to a ``set`` so tag membership
    checks are cheap.

    Returns:
        ``(df, sorted_indices)`` where ``sorted_indices`` maps each sort-option
        label to the index of ``df`` arranged in that order.
    """
    load_started = time.time()
    posts = pd.read_parquet(parquet_file).sort_values(by='post_id', ascending=False)
    posts["tags"] = posts["tags"].apply(set)
    posts = posts.set_index('post_id')

    # Post-ID orders come straight from the (already descending) index.
    sorted_indices = {
        'Post ID (Descending)': posts.index,
        'Post ID (Ascending)': posts.index[::-1],
    }
    # Score orders: highest score first.
    for label, column in (('AVA Score', 'ava_score'), ('Aesthetic Score', 'aesthetic_score')):
        sorted_indices[label] = posts[column].sort_values(ascending=False).index

    print(f"Data loaded and preprocessed: {time.time() - load_started:.2f} seconds")
    return posts, sorted_indices
st.title(f'{data_source} Images')
# Streamlit re-executes this whole script on every widget interaction, which
# previously re-read the parquet file and rebuilt every tag set each time.
# st.cache_resource keys its cache on the function's source and the argument
# (one entry per parquet path) and hands back the same objects without
# copying them, so `data` and `sorted_indices` must not be mutated in place.
cached_loader = st.cache_resource(load_and_preprocess_data, show_spinner="Loading data...")
data, sorted_indices = cached_loader(parquet_file)
# Sidebar filter widgets; their values feed filter_data() below.
score_range = st.sidebar.slider('Select AVA Score range', min_value=0.0, max_value=10.0, value=(5.5, 10.0), step=0.1, help='Filter images based on their AVA Score range.')
score_range_v2 = st.sidebar.slider('Select Aesthetic Score range', min_value=0.0, max_value=10.0, value=(9.0, 10.0), step=0.1, help='Filter images based on their Aesthetic Score range.')

# Post-ID bounds come from the loaded data; use a placeholder range when it is empty.
min_post_id = int(data.index.min()) if not data.empty else 0
max_post_id = int(data.index.max()) if not data.empty else 100000
post_id_range = st.sidebar.slider('Select Post ID range',
                                  min_value=min_post_id,
                                  max_value=max_post_id,
                                  value=(min_post_id, max_post_id),
                                  step=1000,
                                  help='Filter images based on Post ID range.')

# Drop missing ratings first: sorted() cannot compare NaN/None with strings.
if 'rating' in data.columns:
    available_ratings = sorted(data['rating'].dropna().unique().tolist())
else:
    available_ratings = ['general']
selected_ratings = st.sidebar.multiselect(
    'Select ratings to include',
    options=available_ratings,
    default=[],
    help='Filter images by their rating category'
)

page_number = st.sidebar.number_input('Page', min_value=1, value=1, step=1, help='Navigate through the pages of filtered results.')
# Offer exactly the orders the loader precomputed, so the labels cannot drift
# out of sync with the keys filter_data() looks up.
sort_option = st.sidebar.selectbox('Sort by (slow)', options=list(sorted_indices), index=0, help='Select sorting option for the results.')
# Tag query: space-separated tags; a leading "-" marks a tag to exclude.
user_input_tags = st.text_input('Enter tags (space-separated)', value='', help='Filter images based on tags. Use "-" to exclude tags.')
# str.split() with no argument already drops surrounding whitespace and empty tokens.
query_tokens = user_input_tags.split()
selected_tags = {token for token in query_tokens if not token.startswith('-')}
# Ignore a bare "-": it would add an empty-string exclusion that matches nothing
# yet still triggers the per-row tag filter in filter_data().
undesired_tags = {token[1:] for token in query_tokens if token.startswith('-') and len(token) > 1}
print(f"Selected tags: {selected_tags}, Undesired tags: {undesired_tags}")
# Function to filter data based on user input
def filter_data(df, score_range, score_range_v2, post_id_range, selected_tags, sort_option, selected_ratings,
                undesired_tags=None, sorted_indices=None):
    """Return the rows of ``df`` that pass every filter, ordered by ``sort_option``.

    Args:
        df: Posts indexed by post ID with ``ava_score``, ``aesthetic_score`` and
            ``tags`` (a ``set`` per row) columns, plus an optional ``rating`` column.
        score_range: Inclusive ``(low, high)`` bounds for ``ava_score``.
        score_range_v2: Inclusive ``(low, high)`` bounds for ``aesthetic_score``.
        post_id_range: Inclusive ``(low, high)`` bounds for the post-ID index.
        selected_tags: Set of tags every returned row must contain.
        sort_option: Key into ``sorted_indices``. ``'Post ID (Descending)'``
            keeps ``df``'s existing row order.
        selected_ratings: Ratings to keep; an empty value keeps every rating.
            Ignored when ``df`` has no ``rating`` column.
        undesired_tags: Tags no returned row may contain. ``None`` falls back to
            the module-level ``undesired_tags`` (the previous implicit behaviour).
        sorted_indices: Mapping of sort option to the index of ``df`` in that
            order. ``None`` falls back to the module-level ``sorted_indices``.

    Returns:
        The filtered ``DataFrame``, with all of ``df``'s columns.

    Raises:
        KeyError: If ``sort_option`` is neither ``'Post ID (Descending)'`` nor a
            key of ``sorted_indices``.
    """
    # Earlier versions read these two values from module globals only; keep that
    # as the fallback so existing 7-argument call sites behave the same.
    if undesired_tags is None:
        undesired_tags = globals().get('undesired_tags', set())
    if sorted_indices is None:
        sorted_indices = globals().get('sorted_indices', {})

    start_time = time.time()
    filtered_data = df[
        (df['ava_score'] >= score_range[0]) &
        (df['ava_score'] <= score_range[1]) &
        (df['aesthetic_score'] >= score_range_v2[0]) &
        (df['aesthetic_score'] <= score_range_v2[1]) &
        (df.index >= post_id_range[0]) &
        (df.index <= post_id_range[1])
    ]
    if selected_ratings and 'rating' in df.columns:
        filtered_data = filtered_data[filtered_data['rating'].isin(selected_ratings)]
    print(f"Data filtered based on scores, post ID and ratings: {time.time() - start_time:.2f} seconds")

    if sort_option != "Post ID (Descending)":
        # Keep the precomputed global order, restricted to the surviving rows.
        sorted_index = sorted_indices[sort_option]
        sorted_index = sorted_index[sorted_index.isin(filtered_data.index)]
        filtered_data = filtered_data.loc[sorted_index]
    print(f"Applying indices: {time.time() - start_time:.2f} seconds")

    if selected_tags or undesired_tags:
        # Build an explicit bool mask. Series.apply on an empty frame yields an
        # object-dtype Series, which pandas treats as a column selector rather
        # than a row mask, and that would drop every column.
        keep = pd.Series(
            [selected_tags.issubset(tags) and undesired_tags.isdisjoint(tags)
             for tags in filtered_data['tags']],
            index=filtered_data.index,
            dtype=bool,
        )
        filtered_data = filtered_data[keep]
    print(f"Data filtered: {time.time() - start_time:.2f} seconds")
    return filtered_data
# Run every sidebar and tag filter over the loaded posts, then report the match count.
filtered_data = filter_data(
    df=data,
    score_range=score_range,
    score_range_v2=score_range_v2,
    post_id_range=post_id_range,
    selected_tags=selected_tags,
    sort_option=sort_option,
    selected_ratings=selected_ratings,
)
match_count = len(filtered_data)
st.sidebar.write(f"Total filtered images: {match_count}")
# Pagination: `items_per_page` posts per page; page numbers are 1-based.
items_per_page = 50
total_pages = max(1, -(-len(filtered_data) // items_per_page))  # ceiling division
start_idx = (page_number - 1) * items_per_page
end_idx = start_idx + items_per_page
current_data = filtered_data.iloc[start_idx:end_idx]
st.sidebar.write(f"Page {page_number} of {total_pages}")
if page_number > total_pages:
    # The page input has no upper bound; explain the empty grid instead of
    # silently rendering nothing.
    st.warning(f"Page {page_number} is past the last page ({total_pages}).")
# Display the current page as a grid of images, `columns_per_row` per row.
columns_per_row = 5
# Post-page URL for each source. The choice depends only on data_source, so
# resolve it once instead of per image.
post_url_templates = {
    "Danbooru": "https://danbooru.donmai.us/posts/{post_id}",
    "Gelbooru": "https://gelbooru.com/index.php?page=post&s=view&id={post_id}",
    "Rule 34": "https://rule34.xxx/index.php?page=post&s=view&id={post_id}",
}
post_url_template = post_url_templates[data_source]
for row_start in range(0, len(current_data), columns_per_row):
    row = current_data.iloc[row_start:row_start + columns_per_row]
    cols = st.columns(columns_per_row)
    for col, (post_id, row_data) in zip(cols, row.iterrows()):
        with col:
            link = post_url_template.format(post_id=post_id)
            caption = f"ID: {post_id}, AVA Score: {row_data['ava_score']:.2f}, Aesthetic Score: {row_data['aesthetic_score']:.2f}\n{link}"
            image_url = row_data['large_file_url']
            if pd.isna(image_url):
                # A NaN/None URL is not a usable image source; show the caption
                # alone rather than passing it to st.image.
                st.caption(caption)
            else:
                st.image(image_url, caption=caption, use_container_width=True)
def histogram_slider(df, column1, column2, max_samples=10000, nbins=50, random_state=0):
    """Draw overlaid histograms of two numeric columns of ``df`` in the sidebar.

    Args:
        df: Data to plot; must contain ``column1`` and ``column2``.
        column1: First column to histogram; also used as its trace name.
        column2: Second column to histogram; also used as its trace name.
        max_samples: Frames longer than this are randomly subsampled to this
            many rows to keep the chart cheap.
        nbins: Maximum number of bins per trace (Plotly's ``nbinsx``).
        random_state: Seed for the subsample. A fixed seed keeps the chart from
            changing between reruns when the filters have not changed.
    """
    if len(df) > max_samples:
        sample_data = df.sample(max_samples, random_state=random_state)
    else:
        # Sampling every row would only shuffle them; a histogram ignores order.
        sample_data = df
    fig = go.Figure()
    for column in (column1, column2):
        fig.add_trace(go.Histogram(x=sample_data[column], nbinsx=nbins, name=column, opacity=0.75))
    fig.update_layout(
        barmode='overlay',
        bargap=0.1,
        height=200,
        xaxis=dict(showticklabels=True),
        yaxis=dict(showticklabels=True),
        margin=dict(l=0, r=0, t=0, b=0),
        legend=dict(orientation='h', yanchor='bottom', y=-0.4, xanchor='center', x=0.5),
    )
    st.sidebar.plotly_chart(fig, use_container_width=True, config={'displayModeBar': False})
# Score histogram of the filtered posts, timed for the server log.
histogram_started_at = time.time()
histogram_slider(filtered_data, 'ava_score', 'aesthetic_score')
histogram_elapsed = time.time() - histogram_started_at
print(f"Histogram displayed: {histogram_elapsed:.2f} seconds")