Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import pyarrow.parquet as pq | |
| import pyarrow.dataset as ds | |
| import time | |
| import os | |
| import plotly.graph_objects as go | |
| import gc | |
| import numpy as np | |
| from huggingface_hub import login, snapshot_download | |
st.set_page_config(layout="wide")

# huggingface_hub reads HF_HUB_DISABLE_PROGRESS_BARS once, at import time, and
# it has already been imported above. Setting the variable now only affects
# child processes, so to silence the bars in *this* process we also call
# disable_progress_bars(). We only do so when the user has not chosen a value
# themselves, so an explicit HF_HUB_DISABLE_PROGRESS_BARS=0 is still honoured.
if "HF_HUB_DISABLE_PROGRESS_BARS" not in os.environ:
    os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
    from huggingface_hub.utils import disable_progress_bars
    disable_progress_bars()

# The snapshot download below authenticates with this token; fail fast if absent.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise ValueError("environment variable HF_TOKEN not found.")
def init_hf_repo(repo_id='Johnny-Z/dataset_viewer', repo_type='dataset', token=None):
    """Log in to the Hugging Face Hub and download a repository snapshot.

    Args:
        repo_id: Hub repository to download (defaults to the viewer's dataset).
        repo_type: Repository type passed to ``snapshot_download``.
        token: Access token; falls back to the module-level ``hf_token``.

    Returns:
        The local directory path returned by ``snapshot_download``.
    """
    token = token or hf_token
    login(token=token)
    # ``local_dir_use_symlinks`` is deprecated in huggingface_hub and only ever
    # mattered together with ``local_dir`` (not used here), so it is omitted.
    # max_workers=1 limits the download to a single concurrent worker.
    return snapshot_download(
        repo_id,
        repo_type=repo_type,
        max_workers=1,
        token=token,
    )
# Parquet file (inside the downloaded snapshot) that backs each selectable
# source. The radio options are derived from this mapping, so every option the
# user can pick is guaranteed to resolve to a file.
_SOURCE_PARQUET_FILES = {"Danbooru": 'danbooru_data.parquet'}

repo_dir = init_hf_repo()
data_source = st.sidebar.radio("Source", list(_SOURCE_PARQUET_FILES), index=0)
parquet_file = os.path.join(repo_dir, _SOURCE_PARQUET_FILES[data_source])
def _post_id_range_from_statistics(pf):
    """Return ``(min, max)`` of ``post_id`` using only row-group statistics.

    Raises:
        ValueError: if any row group lacks min/max statistics for ``post_id``
            (skipping it could silently narrow the range) or there are no row
            groups at all.
    """
    meta = pf.metadata
    lo = hi = None
    for rg_index in range(meta.num_row_groups):
        row_group = meta.row_group(rg_index)
        chunk = next(
            (row_group.column(c) for c in range(row_group.num_columns)
             if row_group.column(c).path_in_schema == 'post_id'),
            None,
        )
        stats = chunk.statistics if chunk is not None else None
        if stats is None or not stats.has_min_max:
            raise ValueError(f"row group {rg_index} has no post_id min/max statistics")
        lo = stats.min if lo is None else min(lo, stats.min)
        hi = stats.max if hi is None else max(hi, stats.max)
    if lo is None:
        raise ValueError("Invalid post_id range")
    return lo, hi


def _post_id_range_from_data(pf):
    """Return the exact ``(min, max)`` of ``post_id`` by streaming only that column.

    Raises:
        ValueError: if the column holds no non-null values.
    """
    lo = hi = None
    for batch in pf.iter_batches(batch_size=65536, columns=['post_id']):
        values = batch.column(0).to_pandas()
        batch_min, batch_max = values.min(), values.max()
        if pd.isna(batch_min):  # empty or all-null batch
            continue
        lo = batch_min if lo is None else min(lo, batch_min)
        hi = batch_max if hi is None else max(hi, batch_max)
    if lo is None:
        raise ValueError("post_id column has no non-null values")
    return lo, hi


def _sample_ratings(pf, max_row_groups):
    """Return the sorted distinct non-null ``rating`` values of the first row groups."""
    seen = set()
    for rg_index in range(min(max_row_groups, pf.num_row_groups)):
        ratings = pf.read_row_group(rg_index, columns=['rating']).column('rating').to_pandas()
        # Nulls are dropped: mixing None with strings would make sorted() raise.
        seen.update(ratings.dropna().unique())
    return sorted(seen)


def load_parquet_metadata(parquet_file, rating_sample_row_groups=3):
    """Gather the summary figures used to configure the sidebar controls.

    Args:
        parquet_file: Path to the Parquet file to inspect.
        rating_sample_row_groups: Number of leading row groups scanned to collect
            distinct ``rating`` values. This is a sample, not the whole column.

    Returns:
        dict with keys:

        * ``num_rows``: row count from the file footer.
        * ``min_post_id`` / ``max_post_id``: ``int`` bounds of ``post_id``. They
          come from row-group statistics when complete, otherwise from an exact
          scan of that column; ``0`` / ``100000`` when there is no such column.
        * ``available_ratings``: sorted distinct ratings from the sampled row
          groups, or ``['general']`` without a ``rating`` column.
        * ``columns``: column names of a small sample read through pandas.

        On any failure the error is shown with ``st.error`` and a placeholder
        dict with the same keys is returned so the page can still render.
    """
    try:
        # The context manager guarantees the file handle is released.
        with pq.ParquetFile(parquet_file) as pf:
            num_rows = pf.metadata.num_rows
            first_batch = next(pf.iter_batches(batch_size=10), None)
            if first_batch is None:
                raise ValueError("the Parquet file contains no rows")
            sample_df = first_batch.to_pandas()

            if 'post_id' in sample_df.columns:
                try:
                    min_post_id, max_post_id = _post_id_range_from_statistics(pf)
                except Exception as e:
                    # Statistics are optional in Parquet (and their decoding can
                    # fail), so this is a deliberate best-effort fast path.
                    st.warning(f"Unable to get post_id range from statistics: {str(e)}")
                    min_post_id, max_post_id = _post_id_range_from_data(pf)
            else:
                min_post_id, max_post_id = 0, 100000

            if 'rating' in sample_df.columns:
                available_ratings = _sample_ratings(pf, rating_sample_row_groups)
            else:
                available_ratings = ['general']

        print(f"Metadata loaded: {num_rows} rows, post_id range: {min_post_id}-{max_post_id}")
        return {
            'num_rows': num_rows,
            'min_post_id': int(min_post_id),
            'max_post_id': int(max_post_id),
            'available_ratings': available_ratings,
            'columns': sample_df.columns.tolist()
        }
    except Exception as e:
        # UI boundary: report the problem and keep the app usable.
        st.error(f"Error loading Parquet metadata: {str(e)}")
        return {
            'num_rows': 0,
            'min_post_id': 0,
            'max_post_id': 100000,
            'available_ratings': ['general'],
            'columns': []
        }
def get_filtered_batch(parquet_file, filters, needed_columns, sort_option):
    """Read the rows of ``parquet_file`` that match ``filters`` into a DataFrame.

    Filters are pushed down into the pyarrow dataset scan, so only matching
    rows are materialised.

    Args:
        parquet_file: Path to the Parquet file.
        filters: Iterable of ``(column, op, value)`` triples, AND-ed together.
            ``op`` is one of ``'>='``, ``'<='``, ``'>'``, ``'<'``, ``'=='`` or
            ``'!='`` (compare against ``value``), or ``'in'`` (``value`` is a
            collection; an empty collection adds no constraint). Triples with
            any other ``op`` are ignored.
        needed_columns: Columns to read; must include ``'post_id'``.
        sort_option: ``"Post ID (Descending)"``, ``"Post ID (Ascending)"`` or
            ``"Preference Score"`` (highest first). Any other value keeps scan
            order.

    Returns:
        DataFrame indexed by ``post_id``. If reading fails, the error is shown
        with ``st.error`` and an empty DataFrame is returned.
    """
    comparisons = {
        '>=': lambda field, value: field >= value,
        '<=': lambda field, value: field <= value,
        '>': lambda field, value: field > value,
        '<': lambda field, value: field < value,
        '==': lambda field, value: field == value,
        '!=': lambda field, value: field != value,
    }
    try:
        final_filter = None
        for col, op, val in filters:
            field = ds.field(col)
            if op == 'in':
                if len(val) == 0:
                    continue  # empty selection -> no restriction
                expr = field.isin(list(val))
            elif op in comparisons:
                expr = comparisons[op](field, val)
            else:
                continue  # unsupported operator: ignored
            final_filter = expr if final_filter is None else final_filter & expr

        dataset = ds.dataset(parquet_file, format='parquet')
        df = dataset.to_table(columns=needed_columns, filter=final_filter).to_pandas()
        df = df.set_index('post_id')

        if sort_option == "Post ID (Descending)":
            df = df.sort_index(ascending=False)
        elif sort_option == "Post ID (Ascending)":
            df = df.sort_index(ascending=True)
        elif sort_option == "Preference Score":
            df = df.sort_values(by='preference_score', ascending=False)
        return df
    except Exception as e:
        # UI boundary: surface the error and let the page render with no rows.
        st.error(f"Error reading batch: {str(e)}")
        return pd.DataFrame()
def _as_tag_set(tags):
    """Normalise one ``tags`` cell to a ``set`` of tags.

    Accepts a NumPy array or a Python list/tuple/set of tags, a single scalar
    tag (including NumPy scalars), or an empty/missing value (-> empty set).
    """
    if isinstance(tags, np.ndarray):
        return set(tags.tolist())
    if isinstance(tags, np.generic):
        # A NumPy scalar is one tag. Convert it to the Python equivalent and
        # treat it like any other scalar. Passing ``.tolist()`` to ``set()``
        # would split a string into characters and fail on numbers.
        tags = tags.item()
    if isinstance(tags, (list, tuple, set, frozenset)):
        return set(tags)
    return {tags} if tags else set()


def process_tags_for_filtering(df, selected_tags, undesired_tags):
    """Keep only the rows of ``df`` whose ``tags`` satisfy the tag query.

    A row is kept when it carries *every* tag in ``selected_tags`` and *none*
    of the tags in ``undesired_tags``.

    Args:
        df: DataFrame with a ``tags`` column (a collection of tags per row).
        selected_tags: Iterable of required tags (converted to a set).
        undesired_tags: Iterable of excluded tags (converted to a set).

    Returns:
        ``df`` itself when both tag collections are empty or ``df`` has no rows;
        otherwise the matching rows, with index and order preserved.
    """
    required = set(selected_tags) if selected_tags is not None else set()
    excluded = set(undesired_tags) if undesired_tags is not None else set()
    # An empty frame (e.g. the placeholder returned after a read error) may not
    # even have a ``tags`` column, and there is nothing to filter anyway.
    if (not required and not excluded) or df.empty:
        return df
    mask = np.array(
        [required.issubset(tag_set) and excluded.isdisjoint(tag_set)
         for tag_set in map(_as_tag_set, df['tags'])],
        dtype=bool,
    )
    return df[mask]
def get_filtered_data(parquet_file, filters_str, sort_option, selected_tags_str, undesired_tags_str, page_number, items_per_page):
    """Load the rows matching the range/rating filters and the tag query.

    Args:
        parquet_file: Path to the Parquet file.
        filters_str: ``repr()`` of a list of ``(column, op, value)`` triples
            (an already-parsed list is accepted too).
        sort_option: Sort mode forwarded to ``get_filtered_batch``.
        selected_tags_str: ``repr()`` of a list of required tags (or the list).
        undesired_tags_str: ``repr()`` of a list of excluded tags (or the list).
        page_number: Unused here; the caller paginates the returned frame.
        items_per_page: Unused here; the caller paginates the returned frame.

    Returns:
        The filtered, sorted DataFrame indexed by ``post_id``.
    """
    import ast

    def _parse(value):
        # ``ast.literal_eval`` parses plain Python literals (what repr() of
        # lists/tuples of str/int/float produces) without the arbitrary-code
        # execution risk of ``eval``.
        return ast.literal_eval(value) if isinstance(value, str) else value

    filters = _parse(filters_str)
    selected_tags = set(_parse(selected_tags_str))
    undesired_tags = set(_parse(undesired_tags_str))
    needed_columns = ['post_id', 'tags', 'preference_score', 'rating', 'large_file_url']
    df = get_filtered_batch(parquet_file, filters, needed_columns, sort_option)
    if selected_tags or undesired_tags:
        df = process_tags_for_filtering(df, selected_tags, undesired_tags)
    return df
st.title(f'{data_source} Images')
metadata = load_parquet_metadata(parquet_file)

# --- Sidebar filter controls ------------------------------------------------
score_range = st.sidebar.slider('Select Preference Score range', min_value=0.0, max_value=10.0, value=(9.0, 10.0), step=0.5)

min_post_id = metadata['min_post_id']
max_post_id = metadata['max_post_id']
post_id_range = st.sidebar.slider(
    'Select Post ID range',
    min_value=min_post_id,
    max_value=max_post_id,
    value=(min_post_id, max_post_id),
    step=1000,
)

available_ratings = metadata['available_ratings']
selected_ratings = st.sidebar.multiselect(
    'Select ratings to include',
    options=available_ratings,
    default=[],
    help='Filter images by their rating category',
)

page_number = st.sidebar.number_input('Page', min_value=1, value=1, step=1)
items_per_page = 50
sort_option = st.sidebar.selectbox('Sort by', options=['Post ID (Descending)', 'Post ID (Ascending)', 'Preference Score'], index=0)

# --- Tag query ----------------------------------------------------------------
# Whitespace-separated tokens; a leading "-" marks a tag to exclude. A bare "-"
# names no tag, so it is ignored rather than excluding the empty tag "".
user_input_tags = st.text_input('Enter tags (space-separated)', value='', help='Filter images based on tags. Use "-" to exclude tags.')
_tag_tokens = user_input_tags.split()
selected_tags = {tag for tag in _tag_tokens if not tag.startswith('-')}
undesired_tags = {tag[1:] for tag in _tag_tokens if tag.startswith('-') and len(tag) > 1}
# Every query constrains the score and post-id ranges. The rating constraint is
# added only when at least one rating is selected (no selection = all ratings).
filters = [
    ('preference_score', '>=', score_range[0]),
    ('preference_score', '<=', score_range[1]),
    ('post_id', '>=', post_id_range[0]),
    ('post_id', '<=', post_id_range[1]),
]
if selected_ratings:
    filters.append(('rating', 'in', selected_ratings))

# get_filtered_data takes the filters and tag sets as repr() strings.
filters_str = repr(filters)
selected_tags_str = repr(list(selected_tags))
undesired_tags_str = repr(list(undesired_tags))

# perf_counter is monotonic, so the timing is not skewed by wall-clock changes.
start_time = time.perf_counter()
current_batch = get_filtered_data(
    parquet_file, filters_str, sort_option,
    selected_tags_str, undesired_tags_str,
    page_number, items_per_page,
)
print(f"Data retrieved in {time.perf_counter() - start_time:.2f} seconds")

# Paginate locally. Positional slicing clamps at the end, so a page past the
# last one simply yields no rows.
batch_start = (page_number - 1) * items_per_page
end_idx = min(batch_start + items_per_page, len(current_batch))
current_data = current_batch.iloc[batch_start:end_idx]

st.sidebar.write(f"Images on this page: {len(current_data)}")
st.sidebar.write(f"Total filtered sample: {len(current_batch)}")
# Post-page URL template for each source the grid can link to.
_POST_URL_TEMPLATES = {
    "Danbooru": "https://danbooru.donmai.us/posts/{post_id}",
    "Gelbooru": "https://gelbooru.com/index.php?page=post&s=view&id={post_id}",
    "Rule 34": "https://rule34.xxx/index.php?page=post&s=view&id={post_id}",
}

columns_per_row = 5
rows = [current_data.iloc[i:i + columns_per_row] for i in range(0, len(current_data), columns_per_row)]
# Unknown sources get an empty link instead of failing with a NameError.
_url_template = _POST_URL_TEMPLATES.get(data_source, '')
for row in rows:
    cols = st.columns(columns_per_row)
    for col, (post_id, row_data) in zip(cols, row.iterrows()):
        with col:
            link = _url_template.format(post_id=post_id)
            caption = f"ID: {post_id}, Preference Score: {row_data['preference_score']:.2f}\n{link}"
            image_url = row_data['large_file_url']
            # st.image cannot render a missing URL (None/NaN/empty), so show
            # just the caption for such rows instead of aborting the page.
            if image_url is None or image_url == '' or (isinstance(image_url, float) and np.isnan(image_url)):
                st.caption(caption)
            else:
                st.image(
                    image_url,
                    caption=caption,
                    width='stretch'
                )
def histogram_slider(df, column1, bins=30, max_samples=5000):
    """Draw a compact histogram of ``df[column1]`` in the sidebar.

    Frames with more than ``max_samples`` rows are thinned by taking every
    ``len(df) // max_samples``-th row (at most ``max_samples`` rows in total).
    This bounds the work done on each Streamlit rerun. NaNs are dropped before
    binning, and nothing is drawn for an empty frame.

    Args:
        df: Source DataFrame.
        column1: Numeric column to histogram.
        bins: Number of equal-width bins.
        max_samples: Upper bound on rows used to build the histogram (>= 1).
    """
    if df.empty:
        return
    max_samples = max(1, int(max_samples))
    if len(df) > max_samples:
        stride = len(df) // max_samples
        sample_data = df.iloc[np.arange(0, len(df), stride)[:max_samples]]
    else:
        sample_data = df

    counts, bin_edges = np.histogram(sample_data[column1].dropna(), bins=bins)
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=(bin_edges[:-1] + bin_edges[1:]) / 2,  # bin centres
        y=counts,
        name=column1,
        opacity=0.75,
        width=(bin_edges[1] - bin_edges[0]),
    ))
    fig.update_layout(
        barmode='overlay',
        bargap=0.1,
        height=200,
        margin=dict(l=0, r=0, t=0, b=0),
        legend=dict(orientation='h', yanchor='bottom', y=-0.4, xanchor='center', x=0.5),
    )
    st.sidebar.plotly_chart(fig, width='stretch', config={'displayModeBar': False})
    # No explicit del/gc.collect(): the locals are freed by reference counting
    # when the function returns. A forced full collection on every rerun only
    # adds latency.
# Only draw the score histogram when the current filters matched something.
# perf_counter is monotonic, so the timing is not skewed by wall-clock changes.
if not current_batch.empty:
    start_time = time.perf_counter()
    histogram_slider(current_batch, 'preference_score')
    print(f"Histogram displayed: {time.perf_counter() - start_time:.2f} seconds")