| from functools import partial |
| import json |
| from multiprocessing.pool import ThreadPool as Pool |
| import gradio as gr |
| from utils import * |
|
|
| from clip_retrieval.clip_client import ClipClient |
|
|
|
|
def text2image_gr():
    """Build the Gradio Blocks page for CLIP text-to-image search.

    The page offers a query textbox, a result-count slider, model and
    thumbnail radio selectors and a search button; matches are rendered in a
    gallery. Clicking the button calls the nested ``clip_api`` callback.

    Returns:
        The constructed ``gr.Blocks`` object. It is not launched here.
    """

    def clip_api(query_text='', return_n=8, model_name=clip_base, thumbnail="是"):
        """Query the KNN service for ``query_text`` and build gallery items.

        Args:
            query_text: Text to search for.
            return_n: Number of images to request; coerced with ``int``
                before being sent to the client.
            model_name: Accepted so the signature lines up with the four UI
                inputs; it is not used inside this function.
            thumbnail: Forwarded to ``url2img`` as its ``thumbnail`` keyword.

        Returns:
            A list of ``[image, info]`` pairs, where ``image`` is whatever
            ``url2img`` returns for the cover URL and ``info`` is the string
            form of a dict holding ``cover_url``, ``similarity`` (rounded to
            6 decimal places) and ``docid``. Returns ``None`` when the
            service yields no results.
        """
        # NOTE(review): the URL has a double slash before "knn-service".
        # Kept as-is; confirm whether the backend relies on it.
        client = ClipClient(url="http://127.0.0.1:1234//knn-service",
                            indice_name="ltr_cover_index",
                            aesthetic_weight=0,
                            num_images=int(return_n))

        result = client.query(text=query_text)

        if not result:
            print("no result found")
            return None

        print(f"get result sucessed, num: {len(result)}")

        cover_urls = [res['cover_url'] for res in result]
        cover_info = [
            str({"cover_url": res['cover_url'],
                 "similarity": round(res['similarity'], 6),
                 "docid": res['docids']})
            for res in result
        ]

        # Fetch all covers concurrently. The try/finally guarantees the
        # worker threads are released even when one download raises.
        fetch_image = partial(url2img, thumbnail=thumbnail)
        pool = Pool()
        try:
            ret_imgs = pool.map(fetch_image, cover_urls)
        finally:
            pool.close()
            pool.join()

        # pool.map preserves input order, so images and captions line up.
        return [[img, info] for img, info in zip(ret_imgs, cover_info)]

    # Each example row matches the order of `inputs` below:
    # [query text, result count, model, thumbnail choice].
    examples = [
        ["cat", 12, clip_base, "是"],
        ["dog", 12, clip_base, "是"],
        ["bag", 12, clip_base, "是"],
        ["a cat is sit on the table", 12, clip_base, "是"]
    ]

    title = "<h1 align='center'>CLIP文到图搜索应用</h1>"

    # NOTE(review): the nesting of the layout contexts below was reconstructed
    # from a source whose indentation had been flattened; confirm the intended
    # column structure against the rendered page.
    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            # Narrow left column: search controls.
            with gr.Column(scale=1):
                with gr.Column(scale=2):
                    text = gr.Textbox(value="cat", label="请填写文本", elem_id=0, interactive=True)
                num = gr.components.Slider(minimum=0, maximum=50, step=1, value=8,
                                           label="返回图片数(可能被过滤部分)", elem_id=2)
                model = gr.components.Radio(label="模型选择", choices=[clip_base],
                                            value=clip_base, elem_id=3)
                thumbnail = gr.components.Radio(label="是否返回缩略图", choices=[yes, no],
                                                value=yes, elem_id=4)
                btn = gr.Button("搜索")
            # Wide right column: result gallery.
            with gr.Column(scale=100):
                out = gr.Gallery(label="检索结果为:", columns=4, height="auto")
        inputs = [text, num, model, thumbnail]
        btn.click(fn=clip_api, inputs=inputs, outputs=out)
        gr.Examples(examples, inputs=inputs)
    return demo
|
|
if __name__ == "__main__":
    # Script entry point: serve the search page as a single-tab interface
    # bound to the loopback address, without a public share link.
    gr.close_all()

    tab_pages = [text2image_gr()]
    tab_titles = ["文到图搜索"]

    with gr.TabbedInterface(tab_pages, tab_titles) as demo:
        demo.launch(server_name='127.0.0.1', share=False)
|
|