Update app.py
app.py CHANGED
@@ -9,12 +9,14 @@ import pickle
 import requests
 import torch
 
-
-
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# # Load the pre-trained model and processor
+# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
 
-# Load the pre-trained model and processor
-model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 # Load the Unsplash dataset
 dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train") # all 25K images are in train split
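
Both loading paths yield an equivalent zero-shot classifier; this hunk swaps the Hugging Face `transformers` wrappers for the original OpenAI `clip` package, presumably so the model matches the implementation used to precompute the Unsplash embeddings loaded later. For reference, a minimal sketch of the `transformers` route being commented out (the image path and labels are illustrative, not from the app):

```python
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open("photo.jpg")  # illustrative input
labels = ["a photo of a dog", "a photo of a cat"]

# Tokenize the labels and preprocess the image in one call
inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
logits_per_image = model(**inputs).logits_per_image  # shape (1, len(labels))
probs = logits_per_image.softmax(dim=1)
print({k: float(v) for k, v in zip(labels, probs[0])})
```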
@@ -28,6 +30,17 @@ def predict(image, labels):
     probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
     return {k: float(v) for k, v in zip(labels, probs[0])}
 
+
+def predict2(image, labels):
+    image = orig_clip_processor(image).unsqueeze(0).to(device)
+    text = clip.tokenize(labels).to(device)
+    with torch.no_grad():
+        image_features = orig_clip_model.encode_image(image)
+        text_features = orig_clip_model.encode_text(text)
+        logits_per_image, logits_per_text = orig_clip_model(image, text)
+    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
+    return {k: float(v) for k, v in zip(labels, probs[0])}
+
 def rand_image():
     n = dataset.num_rows
     r = random.randrange(0,n)
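
Note that `predict2` computes `image_features` and `text_features` but never uses them: the forward call `orig_clip_model(image, text)` re-encodes both inputs internally. Since the `transformers`-based `model`/`processor` are now commented out, `predict2` is the only working classifier after this commit. If the intermediate encodings are wanted anyway (for example, to cache them), the same probabilities can be derived from the features directly. A sketch, assuming the standard OpenAI CLIP scoring of logit-scaled cosine similarity between unit-normalized embeddings:

```python
import torch

def predict2_from_features(image, labels):
    # hypothetical variant of predict2, reusing the encodings it already computes
    image_input = orig_clip_processor(image).unsqueeze(0).to(device)
    text_tokens = clip.tokenize(labels).to(device)
    with torch.no_grad():
        image_features = orig_clip_model.encode_image(image_input)
        text_features = orig_clip_model.encode_text(text_tokens)
    # CLIP logits are scaled cosine similarities between unit-norm embeddings
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    logits = orig_clip_model.logit_scale.exp() * image_features @ text_features.t()
    probs = logits.softmax(dim=-1).cpu().numpy()
    return {k: float(v) for k, v in zip(labels, probs[0])}
```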

@@ -48,7 +61,6 @@ emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
 with open(emb_filename, 'rb') as emb:
     id2url, img_names, img_emb = pickle.load(emb)
 
-orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
 
 def search(search_query):
 
@@ -124,8 +136,8 @@ with gr.Blocks() as demo:
         label_text.blur(fn=set_labels, inputs=label_text, outputs=labels) # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
         label_text.submit(fn=set_labels, inputs=label_text, outputs=labels) # parse list if user hits enter; ensures that list is fully parsed before classification
         get_btn.click(fn=rand_image, outputs=im)
-        im.change(predict, inputs=[im, labels], outputs=cf)
-        reclass_btn.click(predict, inputs=[im, labels], outputs=cf)
+        im.change(predict2, inputs=[im, labels], outputs=cf)
+        reclass_btn.click(predict2, inputs=[im, labels], outputs=cf)
 
     with gr.Tab("Image Captioning"):
         with gr.Row():
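
The rebinding routes classification through `predict2` both when the image changes (e.g. after `get_btn` loads a random photo) and when the user explicitly reclassifies. A self-contained sketch of this wiring pattern, with component names mirroring the diff and a stub standing in for `predict2`:

```python
import gradio as gr

def classify(image, labels):
    # stand-in for predict2; returns a {label: probability} mapping for gr.Label
    return {label: 1.0 / len(labels) for label in labels}

with gr.Blocks() as demo:
    im = gr.Image(type="pil")
    labels = gr.State(["a photo of a dog", "a photo of a cat"])  # illustrative default
    cf = gr.Label(num_top_classes=2)
    reclass_btn = gr.Button("Reclassify")

    im.change(classify, inputs=[im, labels], outputs=cf)          # fires whenever the image updates
    reclass_btn.click(classify, inputs=[im, labels], outputs=cf)  # manual re-run on the current image

demo.launch()
```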