Update app.py
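This commit swaps the text encoder in app.py from the original OpenAI `clip` package over to the Hugging Face `transformers` CLIP implementation (keeping the original model loaded and the old encoding path as comments), and rewires the Gradio classification callbacks to pass `predict` with explicit `inputs`/`outputs`.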
app.py CHANGED
@@ -12,8 +12,8 @@ import torch
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # # Load the pre-trained model and processor
-
-
+model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
 
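The two new `from_pretrained` calls assume the Hugging Face classes are importable alongside the original `clip` package. A minimal sketch of the import block this loading path needs (the actual import list at the top of app.py is not shown in this diff; only `import torch` is visible in the hunk header):

```python
# Sketch of the assumed imports; not visible in this diff.
import torch
import clip  # original OpenAI CLIP package, still used for orig_clip_model
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face CLIP: model weights plus the matching text/image preprocessor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
```

Both loading paths target the same ViT-B/32 architecture; "openai/clip-vit-base-patch32" is the Hub checkpoint corresponding to the original "ViT-B/32" weights.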
@@ -64,10 +64,19 @@ with open(emb_filename, 'rb') as emb:
 
 def search(search_query):
 
+
+
+
+
     with torch.no_grad():
-
-
-
+
+        # Encode and normalize the description using CLIP (HF CLIP)
+        inputs = processor(text=[text], images=None, return_tensors="pt", padding=True)
+        text_encoded = model.get_text_features(**inputs)
+
+        # # Encode and normalize the description using CLIP (original CLIP)
+        # text_encoded = orig_clip_model.encode_text(clip.tokenize(search_query))
+        # text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
 
 
     # Retrieve the description vector
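Two details of the new encoding path are worth flagging. The four added lines above `with torch.no_grad():` are empty in this view, so the `text` variable passed to the processor is presumably bound from `search_query` there; and the HF path as shown does not re-apply the L2 normalization that the commented-out original-CLIP path performed. A self-contained sketch of the step under those assumptions, reusing `model` and `processor` from the loading sketch above (the `text = search_query` binding and the normalization line are assumptions, not visible code):

```python
def encode_query(search_query):
    # Assumption: the elided added lines bind the query string to `text`
    text = search_query
    with torch.no_grad():
        # Encode the description with HF CLIP
        inputs = processor(text=[text], images=None, return_tensors="pt", padding=True)
        text_encoded = model.get_text_features(**inputs)
        # Assumption: keep the unit-normalization the original-CLIP path did,
        # so dot products against the precomputed image embeddings remain
        # cosine similarities
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    return text_encoded
```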
@@ -136,8 +145,8 @@ with gr.Blocks(css=".caption-text {font-size: 40px !important;}") as demo:
         label_text.blur(fn=set_labels, inputs=label_text, outputs=labels)  # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
         label_text.submit(fn=set_labels, inputs=label_text, outputs=labels)  # parse list if user hits enter; ensures that list is fully parsed before classification
         get_btn.click(fn=rand_image, outputs=im)
-        im.change(
-        reclass_btn.click(
+        im.change(predict, inputs=[im, labels], outputs=cf)
+        reclass_btn.click(predict, inputs=[im, labels], outputs=cf)
 
     with gr.Tab("Image Captioning"):
         with gr.Row():
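`im.change` and `reclass_btn.click` now pass `predict` positionally with explicit `inputs`/`outputs`, so the same classifier runs both when a new image loads and on explicit request. A hypothetical sketch of how these components could be declared and wired; the component constructors, the tab label, and the stub bodies are assumptions (only the event wiring and the names `im`, `labels`, `cf`, `label_text`, `get_btn`, `reclass_btn` appear in the diff; the real `predict`, `rand_image`, and `set_labels` are defined elsewhere in app.py):

```python
import gradio as gr

# Stubs standing in for functions defined elsewhere in app.py
def set_labels(text):
    return [l.strip() for l in text.split(",") if l.strip()]

def rand_image():
    ...  # returns a PIL image in the real app

def predict(img, labels):
    ...  # returns {label: score} for gr.Label in the real app

with gr.Blocks(css=".caption-text {font-size: 40px !important;}") as demo:
    with gr.Tab("Image Classification"):  # tab label assumed
        im = gr.Image(type="pil")                      # image under test
        label_text = gr.Textbox(label="Candidate labels, comma-separated")
        labels = gr.State([])                          # parsed label list
        cf = gr.Label()                                # classification scores
        get_btn = gr.Button("Random image")
        reclass_btn = gr.Button("Classify")

        # Parse the label list when focus leaves the box or Enter is pressed
        label_text.blur(fn=set_labels, inputs=label_text, outputs=labels)
        label_text.submit(fn=set_labels, inputs=label_text, outputs=labels)
        get_btn.click(fn=rand_image, outputs=im)
        # Classify on image change and on explicit request
        im.change(predict, inputs=[im, labels], outputs=cf)
        reclass_btn.click(predict, inputs=[im, labels], outputs=cf)
```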