devan019 commited on
Commit
c60ae80
·
verified ·
1 Parent(s): 072eadd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -16
app.py CHANGED
@@ -1,23 +1,29 @@
1
- from sentence_transformers import SentenceTransformer, util
2
- from PIL import Image
3
  import gradio as gr
4
- import requests
 
 
 
5
 
6
  def get_image_embedding(image):
7
- image_model = SentenceTransformer('clip-ViT-B-32')
8
- img_emb = image_model.encode(image)
9
- return {"embedding": img_emb.tolist()}
10
 
11
  def get_text_embedding(text):
12
- multilingual_text_model = SentenceTransformer('clip-ViT-B-32-multilingual-v1')
13
- text_emb = multilingual_text_model.encode(text)
14
- print(text_emb)
15
- print(type(text_emb))
16
- print(text_emb.ndim)
17
- return {"embedding": text_emb.tolist()}
 
 
 
18
 
19
- image_embedding = gr.Interface(fn=get_image_embedding, inputs=gr.Image(type="pil"), outputs=gr.JSON(api_name="image-embedding"), title="Image Embedding")
20
- text_embedding = gr.Interface(fn=get_text_embedding, inputs=gr.Textbox(), outputs=gr.JSON(api_name="text-embedding"), title="Text Embedding")
 
 
 
21
 
22
- space = gr.TabbedInterface([image_embedding, text_embedding], ["Image Embedding", "Text Embedding"])
23
- space.launch()
 
1
+ from sentence_transformers import SentenceTransformer
 
2
  import gradio as gr
3
+
4
+ # Load once
5
+ image_model = SentenceTransformer("clip-ViT-B-32")
6
+ text_model = SentenceTransformer("clip-ViT-B-32-multilingual-v1")
7
 
8
  def get_image_embedding(image):
9
+ emb = image_model.encode(image)
10
+ return {"embedding": emb.tolist()}
 
11
 
12
  def get_text_embedding(text):
13
+ emb = text_model.encode(text)
14
+ return {"embedding": emb.tolist()}
15
+
16
+ with gr.Blocks() as demo:
17
+ with gr.Tab("Image Embedding"):
18
+ img_input = gr.Image(type="pil")
19
+ img_output = gr.JSON()
20
+ img_btn = gr.Button("Generate")
21
+ img_btn.click(get_image_embedding, img_input, img_output)
22
 
23
+ with gr.Tab("Text Embedding"):
24
+ text_input = gr.Textbox()
25
+ text_output = gr.JSON()
26
+ text_btn = gr.Button("Generate")
27
+ text_btn.click(get_text_embedding, text_input, text_output)
28
 
29
+ demo.launch()