Upload 2 files
Browse files- app.py +98 -0
- caption_index.parquet +3 -0
app.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
# Emoticon-style tags that must keep their underscores when captions are
# rendered (normal tags have '_' replaced by ' '; these would be mangled).
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]

# Pre-built caption index; process_input reads the caption string from the
# DataFrame index plus 'name' and 'score' columns — assumed schema, confirm.
index_file = './caption_index.parquet'


# Loaded once at import time; treated as read-only by process_input.
df = pd.read_parquet(index_file)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def process_input(user_input):
    """Find captions containing every requested tag and sample up to 5 of them.

    Parameters
    ----------
    user_input : str
        Tags separated by ', '; spaces inside a tag are converted to
        underscores to match the index's tag format.

    Returns
    -------
    tuple[str, int]
        HTML string of up to 5 sampled post links with their processed
        captions, and the total number of matching rows.
    """
    user_tags = {tag.replace(' ', '_') for tag in user_input.split(', ')}

    def match_tags(caption, tags):
        # A row matches when every requested tag appears in its caption.
        return tags.issubset(set(caption.split(', ')))

    def process_chunk(chunk):
        # The caption string is the DataFrame index; keep only matching rows.
        chunk = chunk.copy()
        chunk['match'] = chunk.index.to_series().apply(
            lambda cap: match_tags(cap, user_tags))
        return chunk[chunk['match']]

    chunk_size = 100000
    chunks = [df.iloc[i:i + chunk_size] for i in range(0, df.shape[0], chunk_size)]

    # NOTE(review): pandas .apply with a Python lambda is largely GIL-bound,
    # so the thread pool's benefit is workload-dependent — verify by profiling.
    with ThreadPoolExecutor(max_workers=8) as executor:
        results = executor.map(process_chunk, chunks)

    # pd.concat raises on an empty sequence; guard the empty-index edge case.
    filtered_df = pd.concat(results) if chunks else df.iloc[0:0].copy()

    def calculate_weight(score):
        # Scores <= 5 (or unparsable, including None) contribute zero weight.
        try:
            return max(float(score) - 5, 0)
        except (TypeError, ValueError):
            return 0

    filtered_df['weight'] = filtered_df['score'].apply(calculate_weight)

    sample_size = min(5, len(filtered_df))

    if sample_size > 0:
        # Explicit float copy so the in-place normalization below can never
        # write through to the DataFrame's backing array.
        weights = filtered_df['weight'].to_numpy(dtype=float, copy=True)
        total = weights.sum()
        if total > 0 and np.count_nonzero(weights) >= sample_size:
            weights /= total
        else:
            # All-zero weights would make p = NaN, and fewer nonzero entries
            # than sample_size makes replace=False sampling impossible —
            # either way np.random.choice raises. Fall back to uniform.
            weights = None
        sampled_indices = np.random.choice(
            filtered_df.index, size=sample_size, p=weights, replace=False)
        sampled_df = filtered_df.loc[sampled_indices]
    else:
        sampled_df = filtered_df

    # '\(' and '\)' spelled via maketrans to avoid invalid-escape warnings
    # while emitting the exact same two-character sequences.
    escape_parens = str.maketrans({'(': '\\(', ')': '\\)'})
    kaomoji_set = set(kaomojis)

    output = []
    for caption, row in sampled_df.iterrows():
        tags = caption.split(', ')
        # Kaomoji tags keep their underscores; normal tags get spaces back.
        processed_tags = [tag if tag in kaomoji_set else tag.replace('_', ' ')
                          for tag in tags]
        processed_tags = [tag.translate(escape_parens) for tag in processed_tags]
        processed_caption = ', '.join(processed_tags)
        # 'danbooru_<id>' becomes a clickable post URL.
        url = row['name'].replace('danbooru_', 'https://danbooru.donmai.us/posts/')
        output.append(
            f"<a href='{url}' target='_blank'>{url}</a>: {processed_caption}<br>")

    return ''.join(output), len(filtered_df)
|
| 87 |
+
|
| 88 |
+
# Wire the tag-matching function into a simple Gradio UI: one text box in,
# an HTML result list plus a matched-row counter out.
tag_input = gr.Textbox(label="Input tags separated by ', '")
result_html = gr.HTML()
match_count = gr.Number(label="Matched Images Count")

iface = gr.Interface(
    fn=process_input,
    inputs=tag_input,
    outputs=[result_html, match_count],
    title="Prompt Sampling",
    flagging_mode='never',
)
iface.launch()
|
caption_index.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f6ce8a0c716655604d747131a76984a35dc6f15487e038242d640367a8df66db
|
| 3 |
+
size 86522444
|