Spaces:
Runtime error
Runtime error
Add VQA
Browse files- app_caption.py +1 -1
- app_vqa.py +6 -5
- prismer_model.py +28 -16
app_caption.py
CHANGED
|
@@ -18,7 +18,7 @@ def create_demo():
|
|
| 18 |
model_name = gr.Dropdown(label='Model', choices=['Prismer-Base, Prismer-Large'], value='Prismer-Base')
|
| 19 |
run_button = gr.Button('Run')
|
| 20 |
with gr.Column(scale=1.5):
|
| 21 |
-
caption = gr.Text(label='
|
| 22 |
with gr.Row():
|
| 23 |
depth = gr.Image(label='Depth')
|
| 24 |
edge = gr.Image(label='Edge')
|
|
|
|
| 18 |
model_name = gr.Dropdown(label='Model', choices=['Prismer-Base, Prismer-Large'], value='Prismer-Base')
|
| 19 |
run_button = gr.Button('Run')
|
| 20 |
with gr.Column(scale=1.5):
|
| 21 |
+
caption = gr.Text(label='Model Prediction')
|
| 22 |
with gr.Row():
|
| 23 |
depth = gr.Image(label='Depth')
|
| 24 |
edge = gr.Image(label='Edge')
|
app_vqa.py
CHANGED
|
@@ -16,9 +16,10 @@ def create_demo():
|
|
| 16 |
with gr.Column():
|
| 17 |
image = gr.Image(label='Input', type='filepath')
|
| 18 |
model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
|
|
|
|
| 19 |
run_button = gr.Button('Run')
|
| 20 |
with gr.Column(scale=1.5):
|
| 21 |
-
|
| 22 |
with gr.Row():
|
| 23 |
depth = gr.Image(label='Depth')
|
| 24 |
edge = gr.Image(label='Edge')
|
|
@@ -28,8 +29,8 @@ def create_demo():
|
|
| 28 |
object_detection = gr.Image(label='Object Detection')
|
| 29 |
ocr = gr.Image(label='OCR Detection')
|
| 30 |
|
| 31 |
-
inputs = [image, model_name]
|
| 32 |
-
outputs = [
|
| 33 |
|
| 34 |
# paths = sorted(pathlib.Path('prismer/images').glob('*'))
|
| 35 |
# examples = [[path.as_posix(), 'prismer_base'] for path in paths]
|
|
@@ -44,9 +45,9 @@ def create_demo():
|
|
| 44 |
gr.Examples(examples=examples,
|
| 45 |
inputs=inputs,
|
| 46 |
outputs=outputs,
|
| 47 |
-
fn=model.
|
| 48 |
|
| 49 |
-
run_button.click(fn=model.
|
| 50 |
|
| 51 |
|
| 52 |
if __name__ == '__main__':
|
|
|
|
| 16 |
with gr.Column():
|
| 17 |
image = gr.Image(label='Input', type='filepath')
|
| 18 |
model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
|
| 19 |
+
question = gr.Text(label='Question')
|
| 20 |
run_button = gr.Button('Run')
|
| 21 |
with gr.Column(scale=1.5):
|
| 22 |
+
answer = gr.Text(label='Model Prediction')
|
| 23 |
with gr.Row():
|
| 24 |
depth = gr.Image(label='Depth')
|
| 25 |
edge = gr.Image(label='Edge')
|
|
|
|
| 29 |
object_detection = gr.Image(label='Object Detection')
|
| 30 |
ocr = gr.Image(label='OCR Detection')
|
| 31 |
|
| 32 |
+
inputs = [image, model_name, question]
|
| 33 |
+
outputs = [answer, depth, edge, normals, segmentation, object_detection, ocr]
|
| 34 |
|
| 35 |
# paths = sorted(pathlib.Path('prismer/images').glob('*'))
|
| 36 |
# examples = [[path.as_posix(), 'prismer_base'] for path in paths]
|
|
|
|
| 45 |
gr.Examples(examples=examples,
|
| 46 |
inputs=inputs,
|
| 47 |
outputs=outputs,
|
| 48 |
+
fn=model.run_vqa_model)
|
| 49 |
|
| 50 |
+
run_button.click(fn=model.run_vqa_model, inputs=inputs, outputs=outputs)
|
| 51 |
|
| 52 |
|
| 53 |
if __name__ == '__main__':
|
prismer_model.py
CHANGED
|
@@ -16,7 +16,9 @@ submodule_dir = repo_dir / 'prismer'
|
|
| 16 |
sys.path.insert(0, submodule_dir.as_posix())
|
| 17 |
|
| 18 |
from dataset import create_dataset, create_loader
|
|
|
|
| 19 |
from model.prismer_caption import PrismerCaption
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
def download_models() -> None:
|
|
@@ -73,6 +75,11 @@ class Model:
|
|
| 73 |
if exp_name == self.exp_name:
|
| 74 |
return
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
if self.mode == 'caption':
|
| 77 |
config = {
|
| 78 |
'dataset': 'demo',
|
|
@@ -80,13 +87,12 @@ class Model:
|
|
| 80 |
'label_path': 'prismer/helpers/labels',
|
| 81 |
'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
|
| 82 |
'image_resolution': 480,
|
| 83 |
-
'prismer_model':
|
| 84 |
'freeze': 'freeze_vision',
|
| 85 |
-
'prefix': '
|
| 86 |
}
|
| 87 |
-
|
| 88 |
model = PrismerCaption(config)
|
| 89 |
-
state_dict = torch.load(f'prismer/logging/
|
| 90 |
|
| 91 |
elif self.mode == 'vqa':
|
| 92 |
config = {
|
|
@@ -95,13 +101,12 @@ class Model:
|
|
| 95 |
'label_path': 'prismer/helpers/labels',
|
| 96 |
'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
|
| 97 |
'image_resolution': 480,
|
| 98 |
-
'prismer_model':
|
| 99 |
'freeze': 'freeze_vision',
|
| 100 |
-
'prefix': 'A picture of',
|
| 101 |
}
|
| 102 |
|
| 103 |
-
model =
|
| 104 |
-
state_dict = torch.load(f'prismer/logging/
|
| 105 |
|
| 106 |
model.load_state_dict(state_dict)
|
| 107 |
model.eval()
|
|
@@ -131,14 +136,21 @@ class Model:
|
|
| 131 |
return caption, *out_paths
|
| 132 |
|
| 133 |
@torch.inference_mode()
|
| 134 |
-
def run_vqa_model(self, exp_name: str) -> str:
|
| 135 |
self.set_model(exp_name)
|
| 136 |
-
_, test_dataset = create_dataset('
|
| 137 |
test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
|
| 138 |
experts, _ = next(iter(test_loader))
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
sys.path.insert(0, submodule_dir.as_posix())
|
| 17 |
|
| 18 |
from dataset import create_dataset, create_loader
|
| 19 |
+
from dataset.utils import pre_question
|
| 20 |
from model.prismer_caption import PrismerCaption
|
| 21 |
+
from model.prismer_vqa import PrismerVQA
|
| 22 |
|
| 23 |
|
| 24 |
def download_models() -> None:
|
|
|
|
| 75 |
if exp_name == self.exp_name:
|
| 76 |
return
|
| 77 |
|
| 78 |
+
if self.exp_name == 'Prismer-Base':
|
| 79 |
+
model_name = 'prismer_base'
|
| 80 |
+
elif self.exp_name == 'Prismer-Large':
|
| 81 |
+
model_name = 'prismer_large'
|
| 82 |
+
|
| 83 |
if self.mode == 'caption':
|
| 84 |
config = {
|
| 85 |
'dataset': 'demo',
|
|
|
|
| 87 |
'label_path': 'prismer/helpers/labels',
|
| 88 |
'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
|
| 89 |
'image_resolution': 480,
|
| 90 |
+
'prismer_model': model_name,
|
| 91 |
'freeze': 'freeze_vision',
|
| 92 |
+
'prefix': '',
|
| 93 |
}
|
|
|
|
| 94 |
model = PrismerCaption(config)
|
| 95 |
+
state_dict = torch.load(f'prismer/logging/pretrain_{model_name}/pytorch_model.bin', map_location='cuda:0')
|
| 96 |
|
| 97 |
elif self.mode == 'vqa':
|
| 98 |
config = {
|
|
|
|
| 101 |
'label_path': 'prismer/helpers/labels',
|
| 102 |
'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
|
| 103 |
'image_resolution': 480,
|
| 104 |
+
'prismer_model': model_name,
|
| 105 |
'freeze': 'freeze_vision',
|
|
|
|
| 106 |
}
|
| 107 |
|
| 108 |
+
model = PrismerVQA(config)
|
| 109 |
+
state_dict = torch.load(f'prismer/logging/vqa_{model_name}/pytorch_model.bin', map_location='cuda:0')
|
| 110 |
|
| 111 |
model.load_state_dict(state_dict)
|
| 112 |
model.eval()
|
|
|
|
| 136 |
return caption, *out_paths
|
| 137 |
|
| 138 |
@torch.inference_mode()
|
| 139 |
+
def run_vqa_model(self, exp_name: str, question: str) -> str:
|
| 140 |
self.set_model(exp_name)
|
| 141 |
+
_, test_dataset = create_dataset('caption', self.config)
|
| 142 |
test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
|
| 143 |
experts, _ = next(iter(test_loader))
|
| 144 |
+
question = pre_question(question)
|
| 145 |
+
answer = self.model(experts, question, train=False, inference='generate')
|
| 146 |
+
answer = self.tokenizer(answer, max_length=30, padding='max_length', return_tensors='pt').input_ids
|
| 147 |
+
answer = answer.to(experts['rgb'].device)[0]
|
| 148 |
+
answer = self.tokenizer.decode(answer, skip_special_tokens=True)
|
| 149 |
+
answer = answer.capitalize() + '.'
|
| 150 |
+
return answer
|
| 151 |
+
|
| 152 |
+
def run_vqa(self, image_path: str, model_name: str, question: str) -> tuple[str | None, ...]:
|
| 153 |
+
out_paths = run_experts(image_path)
|
| 154 |
+
answer = self.run_vqa_model(model_name, question)
|
| 155 |
+
label_prettify(image_path, out_paths)
|
| 156 |
+
return answer, *out_paths
|