Upload 3 files
Browse files- image2image.py +86 -0
- text2image.py +78 -0
- utils.py +43 -0
image2image.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
import json
|
| 3 |
+
from multiprocessing.pool import ThreadPool as Pool
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import PIL
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from utils import *
|
| 8 |
+
|
| 9 |
+
from clip_retrieval.clip_client import ClipClient
|
| 10 |
+
|
| 11 |
+
def image2text_gr():
    """Build the Gradio demo for CLIP image-to-image search.

    Returns:
        A ``gr.Blocks`` demo: a URL textbox plus search options on the left,
        and a gallery of retrieved covers (with similarity/docid info) on
        the right.

    NOTE(review): despite the name ``image2text_gr``, this tab performs
    image-to-image retrieval — confirm whether the name is intentional.
    """
    def clip_api(query_image=None, return_n=8, model_name=clip_base, thumbnail=yes):
        """Query the knn service with an image URL.

        Returns a list of ``[PIL image, info string]`` pairs for the gallery,
        or ``None`` when the service returns nothing.
        """
        # Fixed duplicated slash in the original URL ("...:1234//knn-service").
        client = ClipClient(url="http://9.135.121.52:1234/knn-service",
                            indice_name="ltr_cover_index",
                            aesthetic_weight=0,
                            num_images=int(return_n))
        result = client.query(image=query_image)

        if not result:
            print("no result found")
            return None

        print(f"get result succeeded, num: {len(result)}")

        cover_urls = [res['cover_url'] for res in result]
        cover_info = [str({"cover_url": res['cover_url'],
                           "similarity": round(res['similarity'], 6),
                           "docid": res['docids']})
                      for res in result]

        # Image download/decode is I/O bound, so a thread pool is appropriate;
        # the context manager replaces the manual close()/join() pair.
        with Pool() as pool:
            ret_imgs = pool.map(partial(url2img, thumbnail=thumbnail), cover_urls)

        return [[img, info] for img, info in zip(ret_imgs, cover_info)]

    examples = [
        ["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/test2014/COCO_test2014_000000000069.jpg", 20,
         clip_base, "是"],
        ["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/test2014/COCO_test2014_000000000080.jpg", 20,
         clip_base, "是"],
        ["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/train2014/COCO_train2014_000000000009.jpg",
         20, clip_base, "是"],
        ["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/train2014/COCO_train2014_000000000308.jpg",
         20, clip_base, "是"]
    ]

    title = "<h1 align='center'>CLIP图到图搜索应用</h1>"

    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Column(scale=2):
                    img = gr.Textbox(value="https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/test2014/COCO_test2014_000000000069.jpg", label="图片地址", elem_id=0, interactive=True)
                    num = gr.components.Slider(minimum=0, maximum=50, step=1, value=8, label="返回图片数(可能被过滤部分)", elem_id=2)
                    model = gr.components.Radio(label="模型选择", choices=[clip_base],
                                                value=clip_base, elem_id=3)
                    tn = gr.components.Radio(label="是否返回缩略图", choices=[yes, no],
                                             value=yes, elem_id=4)
                    btn = gr.Button("搜索", )
            with gr.Column(scale=100):
                out = gr.Gallery(label="检索结果为:", columns=4, height="auto")
        inputs = [img, num, model, tn]
        btn.click(fn=clip_api, inputs=inputs, outputs=out)
        gr.Examples(examples, inputs=inputs)
    return demo
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
    # Wrap the single image-search demo in a tabbed interface and serve it
    # on the loopback interface only (no public share link).
    tabs = [image2text_gr()]
    tab_names = ["图到图搜索"]
    with gr.TabbedInterface(tabs, tab_names) as demo:
        demo.launch(server_name='127.0.0.1', share=False)
text2image.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
import json
|
| 3 |
+
from multiprocessing.pool import ThreadPool as Pool
|
| 4 |
+
import gradio as gr
|
| 5 |
+
from utils import *
|
| 6 |
+
|
| 7 |
+
from clip_retrieval.clip_client import ClipClient
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def text2image_gr():
    """Build the Gradio demo for CLIP text-to-image search.

    Returns:
        A ``gr.Blocks`` demo: a query textbox plus search options on the
        left, and a gallery of retrieved covers (with similarity/docid
        info) on the right.
    """
    def clip_api(query_text='', return_n=8, model_name=clip_base, thumbnail="是"):
        """Query the knn service with a text query.

        Returns a list of ``[PIL image, info string]`` pairs for the gallery,
        or ``None`` when the service returns nothing.
        """
        # Fixed duplicated slash in the original URL ("...:1234//knn-service").
        client = ClipClient(url="http://9.135.121.52:1234/knn-service",
                            indice_name="ltr_cover_index",
                            aesthetic_weight=0,
                            num_images=int(return_n))
        result = client.query(text=query_text)

        if not result:
            print("no result found")
            return None

        print(f"get result succeeded, num: {len(result)}")

        cover_urls = [res['cover_url'] for res in result]
        cover_info = [str({"cover_url": res['cover_url'],
                           "similarity": round(res['similarity'], 6),
                           "docid": res['docids']})
                      for res in result]

        # Image download/decode is I/O bound, so a thread pool is appropriate;
        # the context manager replaces the manual close()/join() pair.
        with Pool() as pool:
            ret_imgs = pool.map(partial(url2img, thumbnail=thumbnail), cover_urls)

        return [[img, info] for img, info in zip(ret_imgs, cover_info)]

    examples = [
        ["cat", 12, clip_base, "是"],
        ["dog", 12, clip_base, "是"],
        ["bag", 12, clip_base, "是"],
        ["a cat is sit on the table", 12, clip_base, "是"]
    ]

    title = "<h1 align='center'>CLIP文到图搜索应用</h1>"

    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Column(scale=2):
                    text = gr.Textbox(value="cat", label="请填写文本", elem_id=0, interactive=True)
                    num = gr.components.Slider(minimum=0, maximum=50, step=1, value=8, label="返回图片数(可能被过滤部分)", elem_id=2)
                    model = gr.components.Radio(label="模型选择", choices=[clip_base],
                                                value=clip_base, elem_id=3)
                    thumbnail = gr.components.Radio(label="是否返回缩略图", choices=[yes, no],
                                                    value=yes, elem_id=4)
                    btn = gr.Button("搜索", )
            with gr.Column(scale=100):
                out = gr.Gallery(label="检索结果为:", columns=4, height="auto")
        inputs = [text, num, model, thumbnail]
        btn.click(fn=clip_api, inputs=inputs, outputs=out)
        gr.Examples(examples, inputs=inputs)
    return demo
| 71 |
+
|
| 72 |
+
if __name__ == "__main__":
    # Close any demos left over from a previous run before launching.
    gr.close_all()
    tabs = [text2image_gr()]
    tab_names = ["文到图搜索"]
    with gr.TabbedInterface(tabs, tab_names) as demo:
        demo.launch(server_name='127.0.0.1', share=False)
utils.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from PIL import ImageFile
|
| 4 |
+
import requests
|
| 5 |
+
import base64
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
|
| 8 |
+
# Display name of the only model offered in the demo UIs.
clip_base = "CLIP(Base)"
# Page description shown at the top of every demo (Chinese).
description = "本项目为CLIP模型的DEMO,可用于图文检索和图像、文本的表征提取,应用于搜索、推荐等应用场景。"

# Radio-button labels meaning "yes" / "no" (Chinese).
yes = "是"
no = "否"

# Host of the clip-retrieval knn service; overridable via environment variable.
server_ip = os.environ.get("CLIP_SERVER_IP", "9.135.121.52")

# Map from model display name to its knn-service endpoint.
# NOTE(review): the demos currently hard-code their own URL (with port 1234)
# instead of reading this mapping — confirm which endpoint is authoritative.
clip_service_url_d = {
    clip_base: f'http://{server_ip}/knn-service',
}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def pil_base64(image, img_format="JPEG"):
    """Serialize a PIL image to a base64-encoded string (default JPEG)."""
    # Relax PIL's global safety limits so very large or truncated images
    # can still be re-encoded without raising.
    Image.MAX_IMAGE_PIXELS = 1000000000
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    buffer = BytesIO()
    image.save(buffer, format=img_format)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def url2img(img_url, thumbnail=yes):
    """Load a cover image from a URL, optionally shrinking it.

    URLs hosted on the internal server 9.22.26.31 are mapped to local file
    paths (the path component after the host); any other URL — e.g. the
    demo example images — is fetched over HTTP. The original code only
    handled the internal-host case and failed on every other URL.

    Returns:
        A PIL RGB image, or ``None`` when loading fails (best-effort:
        errors are printed, not raised, so one bad cover does not break
        the whole gallery).
    """
    try:
        if "9.22.26.31" in img_url:
            # Internal host: everything after the host is a local file path.
            path = img_url.split("9.22.26.31")[1]
            image = Image.open(path).convert("RGB")
        else:
            # Remote URL: stream the response body straight into PIL.
            image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
        longest_side = max(image.size)
        if longest_side > 224 and thumbnail == yes:
            # Integer ratio preserves aspect; the result may still be a
            # little larger than 224 on the longest side.
            ratio = longest_side // 224
            image.thumbnail(size=(image.width // ratio, image.height // ratio))
        return image
    except Exception as e:
        print(e)
        return None