| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| from concurrent.futures import ThreadPoolExecutor | |
| import os | |
# Kaomoji tags. Their underscores are part of the face, so process_input()
# keeps these tags verbatim instead of turning "_" into a space.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]
# Caption index loaded once at start-up. process_input() matches tags against
# this frame's index (", "-separated caption strings) and reads its 'score'
# and 'name' columns.
index_file = './caption_index.parquet'
if not os.path.exists(index_file) and '__file__' in globals():
    # './...' resolves against the current working directory, so launching the
    # app from any other directory would fail; fall back to the copy that sits
    # next to this script. When the CWD-relative file exists it is still used.
    index_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'caption_index.parquet'
    )
df = pd.read_parquet(index_file)
def _parse_user_tags(user_input):
    """Split the raw textbox value into a set of caption-style tags.

    Entries are separated by commas (with or without a following space).
    Surrounding whitespace is dropped, each inner run of whitespace becomes a
    single underscore, and blank entries are ignored, e.g.
    ``"long hair,blue  eyes, "`` -> ``{"long_hair", "blue_eyes"}``.
    """
    return {
        '_'.join(part.split())
        for part in (user_input or '').split(',')
        if part.strip()
    }


def _score_to_weight(score):
    """Return the sampling weight ``max(score - 5, 0)`` for a raw score.

    Scores that do not convert to a finite float (unparsable text, ``None``,
    NaN, infinity) get weight 0 instead of putting a NaN/inf into the
    probability vector.
    """
    try:
        weight = float(score) - 5
    except (TypeError, ValueError):
        return 0.0
    return weight if weight > 0 and np.isfinite(weight) else 0.0


def _filter_by_tags(frame, tags):
    """Return the rows of *frame* whose index caption contains every tag.

    Captions are split on ", ", the same separator used when rendering.
    Non-string index entries never match.
    """
    # One pass over the index. This is pure-Python string work, so a thread
    # pool would not run it in parallel under the GIL.
    mask = np.fromiter(
        (isinstance(caption, str) and tags.issubset(caption.split(', '))
         for caption in frame.index),
        dtype=bool,
        count=len(frame),
    )
    return frame.loc[mask]


def _sample_positions(weights, k, rng):
    """Pick up to *k* distinct row positions, favouring positive weights.

    Rows with positive weight are drawn first, without replacement and with
    probability proportional to their weight. If fewer than *k* rows have a
    positive weight, the remaining slots are filled uniformly from the
    zero-weight rows. This is why all-zero or sparse weights do not make
    ``Generator.choice`` raise.
    """
    k = min(k, len(weights))
    has_weight = weights > 0
    positive = np.flatnonzero(has_weight)
    take = min(k, len(positive))
    picked = []
    if take:
        p = weights[positive] / weights[positive].sum()
        picked.append(rng.choice(positive, size=take, replace=False, p=p))
    if k > take:
        zero = np.flatnonzero(~has_weight)
        picked.append(rng.choice(zero, size=k - take, replace=False))
    return np.concatenate(picked) if picked else np.empty(0, dtype=np.intp)


def process_input(user_input):
    """Sample up to five captions from ``df`` that contain all requested tags.

    Parameters
    ----------
    user_input : str
        Comma-separated tags as typed in the UI, e.g. ``"long hair, smile"``.
        Spaces inside a tag become underscores to match the caption index.

    Returns
    -------
    tuple[str, int]
        An HTML fragment with one ``<a>link</a>: caption<br>`` entry per
        sampled row, and the total number of matching rows. Returns
        ``('', 0)`` when no tag was given or nothing matched.
    """
    import html  # local import keeps this block self-contained

    user_tags = _parse_user_tags(user_input)
    if not user_tags:
        return '', 0

    filtered_df = _filter_by_tags(df, user_tags)
    if filtered_df.empty:
        return '', 0

    weights = np.array(
        [_score_to_weight(score) for score in filtered_df['score']], dtype=float
    )
    # Use an independent generator per call instead of reseeding the global
    # NumPy RNG, which is process-wide state.
    rng = np.random.default_rng()
    # Sample by position (iloc), not by label: the index holds caption strings,
    # which may repeat, and .loc would then return extra rows.
    positions = _sample_positions(weights, 5, rng)
    sampled_df = filtered_df.iloc[positions]

    kaomoji_set = set(kaomojis)
    output = []
    for caption, name in zip(sampled_df.index, sampled_df['name']):
        tags = [
            tag if tag in kaomoji_set else tag.replace('_', ' ')
            for tag in caption.split(', ')
        ]
        # Backslash-escape parentheses. Presumably this lets the caption be
        # pasted into a prompt where bare "(...)" is syntax; confirm.
        tags = [tag.replace('(', '\\(').replace(')', '\\)') for tag in tags]
        # HTML-escape everything interpolated into markup. Otherwise tags such
        # as "<o>_<o>" are parsed as HTML elements and vanish from the page.
        caption_html = html.escape(', '.join(tags))
        url = html.escape(
            str(name).replace('danbooru_', 'https://danbooru.donmai.us/posts/')
        )
        output.append(
            f"<a href='{url}' target='_blank'>{url}</a>: {caption_html}<br>"
        )
    return ''.join(output), len(filtered_df)
# Gradio UI: one tag textbox in; an HTML list of sampled captions and the
# total match count out. Flagging is switched off.
_tag_input = gr.Textbox(label="Input tags separated by ', '")
_result_outputs = [
    gr.HTML(),
    gr.Number(label="Matched Images Count"),
]
iface = gr.Interface(
    fn=process_input,
    inputs=_tag_input,
    outputs=_result_outputs,
    flagging_mode='never',
    title="Prompt Sampling",
)
iface.launch()