Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from Lex import * | |
| ''' | |
| lex = Lexica(query="man woman fire snow").images() | |
| ''' | |
| from PIL import Image | |
| import imagehash | |
| import requests | |
| from time import sleep | |
# Seconds to pause after every image download (used by the search backends),
# presumably to stay polite with the remote image hosts.
sleep_time = 0.5

# Names of the imagehash functions offered in the "Order by" radio; each one
# is resolved with getattr(imagehash, name) when the gallery is re-sorted.
# The list is spelled out explicitly (rather than filtered from
# dir(imagehash)) so the UI order is stable and "average_hash" is the default.
hash_func_name = [
    "average_hash",
    "colorhash",
    "dhash",
    "phash",
    "whash",
    "crop_resistant_hash",
]
def min_dim_to_size(img, size=512):
    """Resize *img* so its longer side equals *size*, keeping the aspect ratio.

    Despite the name, it is the *larger* dimension that is fitted to
    ``size``; images smaller than ``size`` are upscaled.

    Args:
        img: A PIL image (anything exposing ``.size`` as ``(width, height)``
            and ``.resize((width, height))``).
        size: Target length in pixels for the longer side.

    Returns:
        ``(ratio, resized_img)`` where ``ratio = size / max(width, height)``.
    """
    width, height = img.size
    ratio = size / max(width, height)
    # round() instead of int(): truncation can lose a pixel when float error
    # lands just below a whole number (e.g. 1 / 49 * 49 == 0.9999999999999999).
    # max(1, ...) keeps a very thin image from getting a zero-pixel side.
    new_size = tuple(max(1, round(dim * ratio)) for dim in (width, height))
    return ratio, img.resize(new_size)
| #ratio_size = 512 | |
| #ratio, img_rs = min_dim_to_size(img, ratio_size) | |
def image_click(images, evt: gr.SelectData):
    """Return the file path of the gallery item the user clicked.

    Registered via ``outputs.select``; Gradio supplies ``evt`` because of the
    ``gr.SelectData`` annotation, and ``evt.index`` is the clicked position.

    Args:
        images: Current gallery value - a list of dicts whose ``"name"`` key
            holds the image file path (Gradio 3.x gallery format; confirm when
            upgrading Gradio).
        evt: Selection event data provided by Gradio.

    Returns:
        The ``"name"`` of the selected item, which is fed to the image component.
    """
    return images[evt.index]["name"]
def swap_gallery(im, images, func_name):
    """Reorder gallery items by perceptual-hash distance to a reference image.

    Args:
        im: Reference image as an array (the value of the image component),
            or ``None`` when nothing is selected.
        images: Current gallery value - a list of dicts whose ``"name"`` key
            is an image file path. May be ``None``/empty before any search has
            filled the gallery (e.g. when an example image is loaded first).
        func_name: Name of an ``imagehash`` function such as ``"phash"``.

    Returns:
        List of image file paths, closest match first. With no reference
        image the current order is returned unchanged.
    """
    if not images:
        return []
    paths = [item["name"] for item in images]
    if im is None:
        return paths

    hash_func = getattr(imagehash, func_name)
    ref_hash = hash_func(Image.fromarray(im))

    def distance(path):
        # Context manager closes the file right after hashing instead of
        # leaving one open handle per gallery image until GC runs.
        with Image.open(path) as pic:
            return hash_func(pic) - ref_hash

    # sorted() is stable, so equally distant images keep their gallery order.
    return sorted(paths, key=distance)
# Seconds to wait for each remote HTTP request before giving up.
_REQUEST_TIMEOUT = 30


def _download_images(urls, ratio_size):
    """Fetch each URL as a PIL image and resize it with ``min_dim_to_size``.

    Best effort: a URL that fails to download or decode is reported and
    skipped. Sleeps ``sleep_time`` seconds after every request.

    Raises:
        gr.Error: if none of the URLs produced an image.
    """
    from io import BytesIO

    images = []
    for url in urls:
        try:
            with requests.get(url, timeout=_REQUEST_TIMEOUT) as resp:
                resp.raise_for_status()
                im = Image.open(BytesIO(resp.content))
                # Decode now so corrupt data fails here, not later in resize().
                im.load()
        except Exception as err:  # one bad URL must not abort the whole search
            print(f"err: {url}: {err!r}")
        else:
            images.append(im)
        sleep(sleep_time)
    if not images:
        # Explicit error instead of `assert` (which is stripped under -O);
        # gr.Error is shown to the user in the UI.
        raise gr.Error("No images could be retrieved for this prompt.")
    return [min_dim_to_size(im, ratio_size)[1] for im in images]


def lexica(prompt, limit_size=128, ratio_size=256 + 128):
    """Search Lexica for *prompt* and return up to *limit_size* resized images.

    Args:
        prompt: Search text passed to ``Lexica(query=...)``.
        limit_size: Maximum number of result URLs to download.
        ratio_size: Target length of each image's longer side.

    Returns:
        List of PIL images resized with ``min_dim_to_size``.

    Raises:
        gr.Error: if no image could be downloaded.
    """
    urls = Lexica(query=prompt).images()[:limit_size]
    # Swap the "full_jpg" URL segment for "sm2" - presumably a smaller
    # thumbnail variant (TODO confirm against current Lexica URLs).
    urls = [url.replace("full_jpg", "sm2") for url in urls]
    return _download_images(urls, ratio_size)


def enterpix(prompt, limit_size=100, ratio_size=256 + 128, use_key="bigThumbnailUrl"):
    """Search Enterpix for *prompt* and return up to *limit_size* resized images.

    Args:
        prompt: Search text sent to the Enterpix prompt-search endpoint.
        limit_size: Number of results requested (``"length"`` field).
        ratio_size: Target length of each image's longer side.
        use_key: Key of each result entry that holds the image URL.

    Returns:
        List of PIL images resized with ``min_dim_to_size``.

    Raises:
        requests.HTTPError: if the search request itself fails.
        gr.Error: if no image could be downloaded.
    """
    resp = requests.post(
        url="https://www.enterpix.app/enterpix/v1/image/prompt-search",
        data={
            "length": limit_size,
            "platform": "stable-diffusion,midjourney",
            "prompt": prompt,
            "start": 0,
        },
        timeout=_REQUEST_TIMEOUT,
    )
    resp.raise_for_status()
    urls = [item[use_key] for item in resp.json()["images"]]
    return _download_images(urls, ratio_size)
def search(prompt, search_name):
    """Run *prompt* through the backend chosen in the "Search by" radio.

    ``"lexica"`` selects ``lexica``; every other value falls through to
    ``enterpix``.
    """
    backend = lexica if search_name == "lexica" else enterpix
    return backend(prompt)
with gr.Blocks(css="custom.css") as demo:
    title = gr.HTML(
        """<h1><img src="https://i.imgur.com/52VJ8vS.png" alt="SD"> StableDiffusion Image Search </h1>""",
        elem_id="title",
    )
    with gr.Row():
        # Left column: backend choice, prompt box, reference image, examples.
        with gr.Column():
            with gr.Row():
                search_func_name = gr.Radio(
                    choices=["lexica", "enterpix"],
                    value="lexica",
                    label="Search by",
                    elem_id="search_radio",
                )
            with gr.Row():
                inputs = gr.Textbox(
                    label="Prompt",
                    show_label=False,
                    lines=1,
                    max_lines=20,
                    min_width=256,
                    placeholder="Enter prompt to search",
                    elem_id="prompt",
                )
                text_button = gr.Button("Retrieve Images", elem_id="run_button")
            # Reference image: set by clicking a gallery item or by an example;
            # any change re-sorts the gallery.
            i = gr.Image(elem_id="result-image")
            gr.Examples(
                [
                    ["Chinese Traditional Culture", "lexica", "images/AbsoluteReality.png"],
                    ["trending digital art", "lexica", "images/waifu.png"],
                    ["beautiful home", "enterpix", "images/Cyberpunk_Anime.png"],
                    ["interior design of living room", "enterpix", "images/DreamShaper.png"],
                ],
                inputs=[inputs, search_func_name, i],
                label="Examples",
            )
        # Right column: hash-function selector and the result gallery.
        with gr.Column():
            gallery_hint = gr.Markdown(
                value="### Click on a Image in the gallery to select it, and the grid order will change",
                visible=True,
                elem_id="selected_model",
            )
            order_func_name = gr.Radio(
                choices=hash_func_name,
                value=hash_func_name[0],
                label="Order by",
                elem_id="order_radio",
            )
            # NOTE(review): allow_preview is passed to .style(); some Gradio 3.x
            # releases only accept it as a Gallery constructor argument -
            # confirm it actually takes effect with the pinned version.
            outputs = gr.Gallery(label="Output gallery", elem_id="gallery").style(
                grid=5, height=768 - 128, allow_preview=False
            )

    # Clicking a gallery item copies its path into the reference image.
    outputs.select(image_click, outputs, i)
    # Re-sort the gallery when the reference image or hash function changes.
    i.change(
        fn=swap_gallery,
        inputs=[i, outputs, order_func_name],
        outputs=outputs,
        queue=False,
    )
    order_func_name.change(
        fn=swap_gallery,
        inputs=[i, outputs, order_func_name],
        outputs=outputs,
        queue=False,
    )
    text_button.click(search, inputs=[inputs, search_func_name], outputs=outputs)

# Pass the host by keyword: launch()'s first positional parameter is `inline`,
# so launch("0.0.0.0") never set the bind address.
demo.launch(server_name="0.0.0.0")