Spaces:
Running
on
Zero
Running
on
Zero
update to support gradio 4+
Browse files
- app.py +24 -7
- requirements.txt +1 -1
- utils/load_model.py +9 -2
- utils/predict.py +10 -3
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
|
| 2 |
import io
|
| 3 |
import os
|
| 4 |
debug = False
|
|
@@ -29,7 +29,7 @@ PREPROCESS = lambda x: OWLVIT_PRECESSOR(images=x, return_tensors='pt')
|
|
| 29 |
IMAGES_FOLDER = "data/images"
|
| 30 |
# XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
|
| 31 |
IMAGE2GT = json.load(open("data/jsons/image2gt.json", 'r'))
|
| 32 |
-
CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt')
|
| 33 |
CUB_IDX2NAME = json.load(open('data/jsons/cub_desc_idx2name.json', 'r'))
|
| 34 |
CUB_IDX2NAME = {int(k): v for k, v in CUB_IDX2NAME.items()}
|
| 35 |
# correct_predictions = [k for k, v in XCLIP_RESULTS.items() if v['prediction']]
|
|
@@ -269,12 +269,20 @@ def update_selected_image(event: gr.SelectData):
|
|
| 269 |
descs = {k: descs[k] for k in ORDERED_PARTS}
|
| 270 |
custom_text = [custom_class_name] + list(descs.values())
|
| 271 |
descriptions = ";\n".join(custom_text)
|
| 272 |
-
textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
# modified_exp = gr.HTML().update(value="", visible=True)
|
| 274 |
return gt_label, img_base64, xclip_pred_markdown, xclip_exp, current_image, textbox
|
| 275 |
|
| 276 |
def on_edit_button_click_xclip():
|
| 277 |
-
empty_exp = gr.HTML.update(visible=False)
|
|
|
|
| 278 |
|
| 279 |
# Populate the textbox with current descriptions
|
| 280 |
descs = XCLIP_DESC[current_predicted_class.state]
|
|
@@ -282,7 +290,14 @@ def on_edit_button_click_xclip():
|
|
| 282 |
descs = {k: descs[k] for k in ORDERED_PARTS}
|
| 283 |
custom_text = ["class name: custom"] + list(descs.values())
|
| 284 |
descriptions = ";\n".join(custom_text)
|
| 285 |
-
textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
return textbox, empty_exp
|
| 288 |
|
|
@@ -350,10 +365,12 @@ def on_predict_button_click_xclip(textbox_input: str):
|
|
| 350 |
custom_pred_markdown = f"""
|
| 351 |
### <span style='color:{custom_color}'> {custom_label} {custom_pred_score:.4f}</span>
|
| 352 |
"""
|
| 353 |
-
textbox = gr.Textbox.update(visible=False)
|
|
|
|
| 354 |
# return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_explanation
|
| 355 |
|
| 356 |
-
modified_exp = gr.HTML().update(value=modified_explanation, visible=True)
|
|
|
|
| 357 |
return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_exp
|
| 358 |
|
| 359 |
|
|
|
|
| 1 |
+
|
| 2 |
import io
|
| 3 |
import os
|
| 4 |
debug = False
|
|
|
|
| 29 |
IMAGES_FOLDER = "data/images"
|
| 30 |
# XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
|
| 31 |
IMAGE2GT = json.load(open("data/jsons/image2gt.json", 'r'))
|
| 32 |
+
CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt').to(DEVICE)
|
| 33 |
CUB_IDX2NAME = json.load(open('data/jsons/cub_desc_idx2name.json', 'r'))
|
| 34 |
CUB_IDX2NAME = {int(k): v for k, v in CUB_IDX2NAME.items()}
|
| 35 |
# correct_predictions = [k for k, v in XCLIP_RESULTS.items() if v['prediction']]
|
|
|
|
| 269 |
descs = {k: descs[k] for k in ORDERED_PARTS}
|
| 270 |
custom_text = [custom_class_name] + list(descs.values())
|
| 271 |
descriptions = ";\n".join(custom_text)
|
| 272 |
+
# textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
|
| 273 |
+
textbox = gr.Textbox(value=descriptions,
|
| 274 |
+
lines=12,
|
| 275 |
+
visible=True,
|
| 276 |
+
label="XCLIP descriptions",
|
| 277 |
+
interactive=True,
|
| 278 |
+
info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}',
|
| 279 |
+
show_label=False)
|
| 280 |
# modified_exp = gr.HTML().update(value="", visible=True)
|
| 281 |
return gt_label, img_base64, xclip_pred_markdown, xclip_exp, current_image, textbox
|
| 282 |
|
| 283 |
def on_edit_button_click_xclip():
|
| 284 |
+
# empty_exp = gr.HTML.update(visible=False)
|
| 285 |
+
empty_exp = gr.HTML(visible=False)
|
| 286 |
|
| 287 |
# Populate the textbox with current descriptions
|
| 288 |
descs = XCLIP_DESC[current_predicted_class.state]
|
|
|
|
| 290 |
descs = {k: descs[k] for k in ORDERED_PARTS}
|
| 291 |
custom_text = ["class name: custom"] + list(descs.values())
|
| 292 |
descriptions = ";\n".join(custom_text)
|
| 293 |
+
# textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
|
| 294 |
+
textbox = gr.Textbox(value=descriptions,
|
| 295 |
+
lines=12,
|
| 296 |
+
visible=True,
|
| 297 |
+
label="XCLIP descriptions",
|
| 298 |
+
interactive=True,
|
| 299 |
+
info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}',
|
| 300 |
+
show_label=False)
|
| 301 |
|
| 302 |
return textbox, empty_exp
|
| 303 |
|
|
|
|
| 365 |
custom_pred_markdown = f"""
|
| 366 |
### <span style='color:{custom_color}'> {custom_label} {custom_pred_score:.4f}</span>
|
| 367 |
"""
|
| 368 |
+
# textbox = gr.Textbox.update(visible=False)
|
| 369 |
+
textbox = gr.Textbox(visible=False)
|
| 370 |
# return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_explanation
|
| 371 |
|
| 372 |
+
# modified_exp = gr.HTML().update(value=modified_explanation, visible=True)
|
| 373 |
+
modified_exp = gr.HTML(value=modified_explanation, visible=True)
|
| 374 |
return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_exp
|
| 375 |
|
| 376 |
|
requirements.txt
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
torch
|
| 2 |
torchvision
|
| 3 |
-
gradio
|
| 4 |
numpy
|
| 5 |
Pillow
|
| 6 |
transformers
|
|
|
|
| 1 |
torch
|
| 2 |
torchvision
|
| 3 |
+
gradio
|
| 4 |
numpy
|
| 5 |
Pillow
|
| 6 |
transformers
|
utils/load_model.py
CHANGED
|
@@ -1,12 +1,19 @@
|
|
| 1 |
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import torch
|
| 5 |
from transformers import OwlViTProcessor, OwlViTForObjectDetection
|
| 6 |
|
| 7 |
from .model import OwlViTForClassification
|
| 8 |
|
| 9 |
-
@
|
| 10 |
def load_xclip(device: str = "cuda:0",
|
| 11 |
n_classes: int = 183,
|
| 12 |
use_teacher_logits: bool = False,
|
|
|
|
| 1 |
|
| 2 |
|
| 3 |
+
try:
|
| 4 |
+
import spaces
|
| 5 |
+
gpu_decorator = spaces.GPU
|
| 6 |
+
except ImportError:
|
| 7 |
+
# Define a no-operation decorator as fallback
|
| 8 |
+
def gpu_decorator(func):
|
| 9 |
+
return func
|
| 10 |
+
|
| 11 |
import torch
|
| 12 |
from transformers import OwlViTProcessor, OwlViTForObjectDetection
|
| 13 |
|
| 14 |
from .model import OwlViTForClassification
|
| 15 |
|
| 16 |
+
@gpu_decorator
|
| 17 |
def load_xclip(device: str = "cuda:0",
|
| 18 |
n_classes: int = 183,
|
| 19 |
use_teacher_logits: bool = False,
|
utils/predict.py
CHANGED
|
@@ -1,4 +1,11 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import PIL
|
| 3 |
import torch
|
| 4 |
|
|
@@ -30,7 +37,7 @@ def encode_descs_xclip(owlvit_det_processor: callable, model: callable, descs: l
|
|
| 30 |
# text_embeds = torch.cat(text_embeds, dim=0)
|
| 31 |
# text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1)
|
| 32 |
# return text_embeds.to(device)
|
| 33 |
-
@
|
| 34 |
def xclip_pred(new_desc: dict,
|
| 35 |
new_part_mask: dict,
|
| 36 |
new_class: str,
|
|
@@ -76,7 +83,7 @@ def xclip_pred(new_desc: dict,
|
|
| 76 |
n_classes = 201
|
| 77 |
query_tokens = owlvit_processor(text=list(new_desc_.values()), padding="max_length", truncation=True, return_tensors="pt").to(device)
|
| 78 |
new_class_embed = model.owlvit.get_text_features(**query_tokens)
|
| 79 |
-
query_embeds = torch.cat([cub_embeds, new_class_embed], dim=0)
|
| 80 |
modified_class_idx = 200
|
| 81 |
else:
|
| 82 |
n_classes = 200
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
import spaces
|
| 3 |
+
gpu_decorator = spaces.GPU
|
| 4 |
+
except ImportError:
|
| 5 |
+
# Define a no-operation decorator as fallback
|
| 6 |
+
def gpu_decorator(func):
|
| 7 |
+
return func
|
| 8 |
+
|
| 9 |
import PIL
|
| 10 |
import torch
|
| 11 |
|
|
|
|
| 37 |
# text_embeds = torch.cat(text_embeds, dim=0)
|
| 38 |
# text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1)
|
| 39 |
# return text_embeds.to(device)
|
| 40 |
+
@gpu_decorator
|
| 41 |
def xclip_pred(new_desc: dict,
|
| 42 |
new_part_mask: dict,
|
| 43 |
new_class: str,
|
|
|
|
| 83 |
n_classes = 201
|
| 84 |
query_tokens = owlvit_processor(text=list(new_desc_.values()), padding="max_length", truncation=True, return_tensors="pt").to(device)
|
| 85 |
new_class_embed = model.owlvit.get_text_features(**query_tokens)
|
| 86 |
+
query_embeds = torch.cat([cub_embeds, new_class_embed], dim=0).to(device)
|
| 87 |
modified_class_idx = 200
|
| 88 |
else:
|
| 89 |
n_classes = 200
|