File size: 6,665 Bytes
5c4ad21
 
 
f1148e7
5c4ad21
 
291f7ae
5c4ad21
 
291f7ae
 
 
 
 
 
 
 
5c4ad21
291f7ae
5c4ad21
291f7ae
5c4ad21
 
 
 
 
 
 
291f7ae
 
5c4ad21
 
 
 
291f7ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c4ad21
 
291f7ae
5c4ad21
 
291f7ae
5c4ad21
 
 
 
 
291f7ae
5c4ad21
 
 
291f7ae
 
 
 
 
 
5c4ad21
291f7ae
 
 
 
 
5c4ad21
 
 
 
 
 
 
 
 
 
 
 
 
 
291f7ae
5c4ad21
 
 
291f7ae
5c4ad21
 
 
 
 
 
 
 
 
 
 
291f7ae
 
 
 
 
 
 
 
 
5c4ad21
 
291f7ae
5c4ad21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291f7ae
5c4ad21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import streamlit as st
import pandas as pd
import time
import os
import plotly.graph_objects as go


st.set_page_config(layout="wide")

# Each selectable source maps to the environment variable holding the path of
# its parquet file.
SOURCE_ENV_VARS = {
    "Danbooru": "PARQUET_FILE1",
    "Gelbooru": "PARQUET_FILE2",
    "Rule 34": "PARQUET_FILE3",
}

data_source = st.sidebar.radio("Source", list(SOURCE_ENV_VARS), index=0)
parquet_file = os.getenv(SOURCE_ENV_VARS[data_source])
if not parquet_file:
    # Without this guard an unset variable would be passed on as None to
    # pd.read_parquet when the data is loaded.
    st.error(
        f"Environment variable {SOURCE_ENV_VARS[data_source]} is not set; "
        f"cannot load {data_source} data."
    )
    st.stop()

@st.cache_resource
def load_and_preprocess_data(parquet_file):
    """Load a parquet file and prepare it for filtering.

    The rows are ordered by ``post_id`` (highest first) and the frame is
    re-indexed on ``post_id``. Each row's ``tags`` value is converted to a
    ``set`` so tag queries can use set operations. The function is wrapped in
    ``st.cache_resource``, so Streamlit caches the result per ``parquet_file``.

    Returns:
        ``(df, sorted_indices)``, where ``sorted_indices`` maps each sort-option
        label shown in the sidebar to the post-id index in that order.
    """
    start_time = time.time()
    frame = pd.read_parquet(parquet_file).sort_values(by='post_id', ascending=False)
    frame["tags"] = frame["tags"].apply(set)
    frame = frame.set_index('post_id')

    def descending_by(column):
        # Post ids ordered from the highest to the lowest value of ``column``.
        return frame[column].sort_values(ascending=False).index

    sorted_indices = {
        'Post ID (Descending)': frame.index,
        'Post ID (Ascending)': frame.index[::-1],
        'AVA Score': descending_by('ava_score'),
        'Aesthetic Score': descending_by('aesthetic_score'),
    }
    print(f"Data loaded and preprocessed: {time.time() - start_time:.2f} seconds")
    return frame, sorted_indices

st.title(f'{data_source} Images')
data, sorted_indices = load_and_preprocess_data(parquet_file)

score_range = st.sidebar.slider('Select AVA Score range', min_value=0.0, max_value=10.0, value=(5.5, 10.0), step=0.1, help='Filter images based on their AVA Score range.')
score_range_v2 = st.sidebar.slider('Select Aesthetic Score range', min_value=0.0, max_value=10.0, value=(9.0, 10.0), step=0.1, help='Filter images based on their Aesthetic Score range.')

# Bound the Post ID slider by the data. Fall back to a fixed placeholder range
# when the frame has no rows.
min_post_id = int(data.index.min()) if not data.empty else 0
max_post_id = int(data.index.max()) if not data.empty else 100000
post_id_range = st.sidebar.slider('Select Post ID range',
                                  min_value=min_post_id,
                                  max_value=max_post_id,
                                  value=(min_post_id, max_post_id),
                                  step=1000,
                                  help='Filter images based on Post ID range.')

# Rating choices come from the data. Missing ratings (None/NaN) are dropped
# first: they are not meaningful choices, and mixing them with rating strings
# would make sorted() raise TypeError (assumes ratings are stored as strings).
if 'rating' in data.columns:
    available_ratings = sorted(data['rating'].dropna().unique().tolist())
else:
    available_ratings = ['general']

selected_ratings = st.sidebar.multiselect(
    'Select ratings to include',
    options=available_ratings,
    default=[],
    help='Filter images by their rating category'
)

page_number = st.sidebar.number_input('Page', min_value=1, value=1, step=1, help='Navigate through the pages of filtered results.')
sort_option = st.sidebar.selectbox('Sort by (slow)', options=['Post ID (Descending)', 'Post ID (Ascending)', 'AVA Score', 'Aesthetic Score'], index=0, help='Select sorting option for the results.')

# user input
user_input_tags = st.text_input('Enter tags (space-separated)', value='', help='Filter images based on tags. Use "-" to exclude tags.')
# str.split() with no argument already drops surrounding whitespace and empty
# pieces, so every token is non-empty. A token starting with "-" excludes the
# tag after the dash. A bare "-" names no tag and is ignored; previously it
# became an empty-string exclusion.
query_tokens = user_input_tags.split()
selected_tags = {tag for tag in query_tokens if not tag.startswith('-')}
undesired_tags = {tag[1:] for tag in query_tokens if tag.startswith('-') and len(tag) > 1}
print(f"Selected tags: {selected_tags}, Undesired tags: {undesired_tags}")

# Function to filter data based on user input
def filter_data(df, score_range, score_range_v2, post_id_range, selected_tags, sort_option, selected_ratings,
                undesired_tags=None, sorted_indices=None):
    """Return the rows of ``df`` that match the sidebar filters, in the requested order.

    Args:
        df: Frame indexed by post id, with ``ava_score``, ``aesthetic_score``
            and ``tags`` (a set per row) columns, plus an optional ``rating``
            column.
        score_range: Inclusive ``(low, high)`` bounds on ``ava_score``.
        score_range_v2: Inclusive ``(low, high)`` bounds on ``aesthetic_score``.
        post_id_range: Inclusive ``(low, high)`` bounds on the index.
        selected_tags: Set of tags every kept row must contain.
        sort_option: Key into ``sorted_indices``. ``"Post ID (Descending)"``
            keeps ``df``'s existing row order.
        selected_ratings: Ratings to keep. Empty means no rating filter.
        undesired_tags: Set of tags no kept row may contain. ``None`` falls
            back to the module-level ``undesired_tags`` (or an empty set if
            none exists).
        sorted_indices: Mapping from sort-option label to a pre-ordered post-id
            index. ``None`` falls back to the module-level ``sorted_indices``.

    Returns:
        The matching subset of ``df``, always with the same columns as ``df``
        (even when no rows match).
    """
    # The two trailing parameters let the function run without the script's
    # globals. The fallbacks keep existing callers that omit them working.
    module_globals = globals()
    if undesired_tags is None:
        undesired_tags = module_globals.get('undesired_tags', set())

    start_time = time.time()

    filtered_data = df[
        (df['ava_score'] >= score_range[0]) &
        (df['ava_score'] <= score_range[1]) &
        (df['aesthetic_score'] >= score_range_v2[0]) &
        (df['aesthetic_score'] <= score_range_v2[1]) &
        (df.index >= post_id_range[0]) &
        (df.index <= post_id_range[1])
    ]

    if selected_ratings and 'rating' in df.columns:
        filtered_data = filtered_data[filtered_data['rating'].isin(selected_ratings)]

    print(f"Data filtered based on scores, post ID and ratings: {time.time() - start_time:.2f} seconds")

    if sort_option != "Post ID (Descending)":
        if sorted_indices is None:
            sorted_indices = module_globals['sorted_indices']
        # Reuse the precomputed full ordering, restricted to the surviving rows.
        sorted_index = sorted_indices[sort_option]
        sorted_index = sorted_index[sorted_index.isin(filtered_data.index)]
        filtered_data = filtered_data.loc[sorted_index]
        print(f"Applying indcies: {time.time() - start_time:.2f} seconds")

    if selected_tags or undesired_tags:
        tag_mask = filtered_data['tags'].apply(
            lambda tags: selected_tags.issubset(tags) and undesired_tags.isdisjoint(tags)
        )
        # On an empty frame apply() yields an empty *object* Series, which
        # pandas does not treat as a boolean mask. df[...] would then select
        # zero columns and drop every column. Forcing bool keeps it a row mask.
        filtered_data = filtered_data[tag_mask.astype(bool)]

    print(f"Data filtered: {time.time() - start_time:.2f} seconds")
    return filtered_data

# Filter data
filtered_data = filter_data(data, score_range, score_range_v2, post_id_range, selected_tags, sort_option, selected_ratings)
st.sidebar.write(f"Total filtered images: {len(filtered_data)}")

# Pagination
items_per_page = 50
# Ceiling division; an empty result still counts as one (empty) page.
total_pages = max(1, (len(filtered_data) + items_per_page - 1) // items_per_page)
start_idx = (page_number - 1) * items_per_page
end_idx = start_idx + items_per_page
current_data = filtered_data.iloc[start_idx:end_idx]
if page_number > total_pages:
    # The page number_input has no upper bound, so a page past the end would
    # otherwise render an empty grid with no explanation.
    st.info(f"Page {page_number} is past the last page ({total_pages}) of the filtered results.")

# Post-page URL template per source; the post id fills the placeholder.
post_url_templates = {
    "Danbooru": "https://danbooru.donmai.us/posts/{}",
    "Gelbooru": "https://gelbooru.com/index.php?page=post&s=view&id={}",
    "Rule 34": "https://rule34.xxx/index.php?page=post&s=view&id={}",
}
post_url_template = post_url_templates[data_source]

# Display the data
columns_per_row = 5
rows = [current_data.iloc[i:i + columns_per_row] for i in range(0, len(current_data), columns_per_row)]
for row in rows:
    cols = st.columns(columns_per_row)
    for col, (post_id, row_data) in zip(cols, row.iterrows()):
        with col:
            link = post_url_template.format(post_id)
            st.image(row_data['large_file_url'], caption=f"ID: {row_data.name}, AVA Score: {row_data['ava_score']:.2f}, Aesthetic Score: {row_data['aesthetic_score']:.2f}\n{link}", use_container_width=True)
            
def histogram_slider(df, column1, column2, sample_size=10000, random_state=0):
    """Draw overlaid histograms of two numeric columns of ``df`` in the sidebar.

    At most ``sample_size`` rows are sampled (without replacement) to keep the
    chart cheap for large result sets. Sampling uses a fixed ``random_state``,
    so the same filtered data gives the same chart on every Streamlit rerun.
    Previously a fresh random sample was drawn each time any widget changed.
    Pass ``random_state=None`` for unseeded sampling.

    Args:
        df: Frame containing ``column1`` and ``column2``.
        column1: First column to plot; also used as its legend label.
        column2: Second column to plot; also used as its legend label.
        sample_size: Upper bound on the number of rows plotted.
        random_state: Seed forwarded to ``DataFrame.sample``.
    """
    sample_data = df.sample(min(sample_size, len(df)), random_state=random_state)

    fig = go.Figure()
    fig.add_trace(go.Histogram(x=sample_data[column1], nbinsx=50, name=column1, opacity=0.75))
    fig.add_trace(go.Histogram(x=sample_data[column2], nbinsx=50, name=column2, opacity=0.75))
    fig.update_layout(
        barmode='overlay',  # draw the two distributions on top of each other
        bargap=0.1,
        height=200,
        xaxis=dict(showticklabels=True),
        yaxis=dict(showticklabels=True),
        margin=dict(l=0, r=0, t=0, b=0),
        legend=dict(orientation='h', yanchor='bottom', y=-0.4, xanchor='center', x=0.5),
    )
    st.sidebar.plotly_chart(fig, use_container_width=True, config={'displayModeBar': False})

# histogram
# Draw the score-distribution chart for the current filter results and log
# how long it took to render.
histogram_started = time.time()
histogram_slider(filtered_data, 'ava_score', 'aesthetic_score')
print(f"Histogram displayed: {time.time() - histogram_started:.2f} seconds")